In [None]:
import pandas as pd
import regex as re

import ahocorasick

import torch
from transformers import T5Tokenizer, MT5ForConditionalGeneration

from tqdm import tqdm

In [None]:
# Local checkpoint directory passed to `from_pretrained` for both the tokenizer
# and the model (the name suggests a fine-tuned ruT5-small checkpoint).
USED_MODEL_NAME = './coint_rut5small_finetune_fulltrain_novalid'

# Token budget used for input truncation/padding and for generation length.
SEQ_MAX_LENGTH = 150

# Marker that separates individual predicted phrases in the decoded model output.
SEP_TOKEN = '▁<extra_id_0>'

In [None]:
# Prefer the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device ('cuda' if torch.cuda.is_available () else 'cpu')

tokenizer = T5Tokenizer.from_pretrained (USED_MODEL_NAME)

# nn.Module.to moves the weights and returns the same module object.
model = MT5ForConditionalGeneration.from_pretrained (USED_MODEL_NAME).to (device)
print (model.device)

In [None]:
def raw_splitter (text, delimiters = ('...', '.', '?!', '?', '!')):
    """Split ``text`` into sentence pieces while tracking character offsets.

    The text is first cut at every position that follows a newline plus any
    number of spaces. Each resulting piece is then cut right after one of
    ``delimiters`` followed by a single space, when the next character is an
    uppercase Latin (A-Z) or Cyrillic (А-Я, Ё) letter. ``re.split`` is used
    without capture groups, so no characters are dropped and the pieces
    concatenate back to ``text``.

    Args:
        text: The string to split.
        delimiters: Non-empty iterable of sentence-ending strings; each one is
            matched literally and as a whole (so '...' and '?!' work).

    Returns:
        A list of ``(sentence, (start, end))`` tuples with
        ``text[start:end] == sentence``.
    """
    # The delimiters are joined as an alternation and wrapped in a group. The
    # previous code placed this alternation inside a character class, which
    # turned the '|' separators into a literal delimiter character.
    delimiters_pattern = '|'.join (map (re.escape, delimiters))
    sentence_boundary = f'(?<=(?:{delimiters_pattern}) )(?=[A-ZА-ЯЁ])'

    # Variable-width look-behinds: these require the `regex` module (imported as `re`).
    paragraphs = re.split ('(?<=\n[ ]*)', text)

    sentences_with_indices = []
    current_start_index = 0

    for paragraph in paragraphs:
        for sentence in re.split (sentence_boundary, paragraph):
            end_index = current_start_index + len (sentence)
            sentences_with_indices.append ((sentence, (current_start_index, end_index)))
            current_start_index = end_index

    return sentences_with_indices


LABEL_PREFIX_TOKEN = '▁<extra_id_1>'
def get_set (tensor, tokenizer = tokenizer):
    """Decode one generated id sequence into the set of predicted phrases.

    Ids 0, 1 and -100 are removed before decoding. For T5-style vocabularies
    these are presumably pad, EOS and the loss ignore index (TODO confirm for
    this tokenizer). The decoded string is split on ``SEP_TOKEN`` and every
    piece is stripped.

    Args:
        tensor: Tensor of token ids (one generated sequence).
        tokenizer: Tokenizer whose ``decode`` turns the ids back into text.

    Returns:
        A non-empty set of strings. Empty pieces are dropped when other pieces
        exist. The label-prefix token text is always dropped. ``{''}`` is
        returned when nothing remains.
    """
    kept_ids = tensor [(tensor != 0) & (tensor != 1) & (tensor != - 100)]

    decoded = tokenizer.decode (kept_ids)
    phrases = {piece.strip () for piece in decoded.split (SEP_TOKEN)}

    if len (phrases) > 1:
        phrases.discard ('')
    # [1:] strips the leading '▁' so the bare '<extra_id_1>' text is removed.
    phrases.discard (LABEL_PREFIX_TOKEN [1:])

    return phrases if phrases else {''}



def one_finder (text, phrases):
    """Find whole-word occurrences of any of ``phrases`` inside ``text``.

    A match is rejected when the character immediately before or after it is
    a letter (``str.isalpha``), so phrases are not reported inside longer
    words. Overlapping matches of different phrases are all kept.

    Args:
        text: The string to search.
        phrases: Iterable of phrases to look for; empty strings are ignored.

    Returns:
        A list of ``(start, end, phrase)`` tuples with
        ``text[start:end] == phrase`` (``end`` exclusive), in the order the
        automaton reports them. Returns an empty list when there is no
        non-empty phrase to search for.
    """
    # pyahocorasick refuses to `iter` over an automaton that was never given
    # any words, so an empty (or all-empty-string) phrase collection must
    # short-circuit here instead of crashing.
    words = [phrase for phrase in phrases if phrase]
    if not words:
        return []

    automaton = ahocorasick.Automaton ()
    for idx, phrase in enumerate (words):
        automaton.add_word (phrase, (idx, phrase))
    automaton.make_automaton ()

    found = []
    for end_index, (_, phrase) in automaton.iter (text):
        # `end_index` is the index of the last matched character (inclusive).
        start_index = end_index - len (phrase) + 1
        after_index = end_index + 1

        if start_index > 0 and text [start_index - 1].isalpha ():
            continue
        if after_index < len (text) and text [after_index].isalpha ():
            continue

        found.append ((start_index, after_index, phrase))

    return found

In [None]:

def predict_with_model (texts, model = model, tokenizer = tokenizer):
    """Predict phrase spans for each text, generating one sentence at a time.

    Each text is split with ``raw_splitter``. Every sentence is then:
    encoded (padded/truncated to ``SEQ_MAX_LENGTH``), run through
    ``model.generate``, decoded into a phrase set with ``get_set``, and
    located in the sentence with ``one_finder``.

    Args:
        texts: Iterable of input strings.
        model: Seq2seq model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer used both to encode the inputs and to decode the
            generated ids.

    Returns:
        One list per text of ``[start, end]`` character spans (``end``
        exclusive), with offsets relative to the full text.
    """
    model.eval ()

    predictions = []

    for text in tqdm (texts):

        answers = []
        for sentence, (start, _) in raw_splitter (text):

            encoded = tokenizer (sentence, padding = 'max_length', truncation = True, max_length = SEQ_MAX_LENGTH, return_tensors = 'pt')

            with torch.no_grad ():
                out = model.generate (
                    input_ids = encoded ['input_ids'].to (model.device),
                    attention_mask = encoded ['attention_mask'].to (model.device),
                    max_length = SEQ_MAX_LENGTH,
                )

            # Decode with the same tokenizer that encoded the input. Without the
            # explicit argument, get_set falls back to its module-level default.
            term_set = get_set (out [0], tokenizer = tokenizer)

            found = one_finder (sentence, term_set)
            # Shift sentence-local offsets to full-text offsets.
            answers += [[item [0] + start, item [1] + start] for item in found]

        predictions.append (answers)

    return predictions



def predict_with_model_effective (texts, model = model, tokenizer = tokenizer, batch_size = 8):
    """Predict phrase spans for each text, generating sentences in batches.

    Each text is split with ``raw_splitter``. Its sentences are then encoded
    and generated ``batch_size`` at a time. Each generated sequence is decoded
    into a phrase set with ``get_set`` and located in its sentence with
    ``one_finder``.

    Args:
        texts: Iterable of input strings.
        model: Seq2seq model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer used both to encode the inputs and to decode the
            generated ids.
        batch_size: Positive number of sentences per ``generate`` call
            (default 8, the value that used to be hard-coded).

    Returns:
        One list per text of ``[start, end]`` character spans (``end``
        exclusive), with offsets relative to the full text.
    """
    model.eval ()

    predictions = []

    for text in tqdm (texts):

        sentences_w_ind = raw_splitter (text)
        sentences = [sentence for sentence, _ in sentences_w_ind]

        # One generated sequence per sentence, in sentence order. A single loop
        # covers both the "fits in one batch" and the "needs several" cases.
        out = []
        for batch_start in range (0, len (sentences), batch_size):
            batch_sentences = sentences [batch_start: batch_start + batch_size]

            encoded = tokenizer (batch_sentences, padding = 'max_length', truncation = True, max_length = SEQ_MAX_LENGTH, return_tensors = 'pt')

            with torch.no_grad ():
                generated = model.generate (
                    input_ids = encoded ['input_ids'].to (model.device),
                    attention_mask = encoded ['attention_mask'].to (model.device),
                    max_length = SEQ_MAX_LENGTH,
                )

            # Iterating the 2-D output yields one id tensor per sentence.
            out.extend (generated)

        answers = []
        for (sentence, (start, _)), output in zip (sentences_w_ind, out):
            # Decode with the same tokenizer that encoded the input.
            term_set = get_set (output, tokenizer = tokenizer)

            found = one_finder (sentence, term_set)
            # Shift sentence-local offsets to full-text offsets.
            answers += [[item [0] + start, item [1] + start] for item in found]

        predictions.append (answers)

    return predictions

In [None]:
def label_constructor (labels):
    """Drop the class field from gold annotations.

    Args:
        labels: Iterable (one entry per text) of iterables of
            ``(start, end, cls)`` triples.

    Returns:
        A list (one entry per text) of ``[start, end]`` lists.
    """
    return [[[start, end] for start, end, _cls in label] for label in labels]

In [None]:
# Validation split: one JSON object per line; keep only the text and its annotations.
df = pd.read_json ('./test_data/test1_t12_full_v2.jsonl', lines = True) [['text', 'label']]
print (df.head ())

In [None]:
# Inputs and gold annotations of the validation split, as separate Series.
val_data_txt, val_labels_txt = df ['text'], df ['label']

In [None]:
# Preview a few inputs with rich display instead of printing the whole series.
val_data_txt.head ()

In [None]:
def comparator (pred, labl):
    """Count span matches between one prediction list and one gold list.

    Spans are compared as exact tuples; duplicate spans collapse to one.

    Args:
        pred: Iterable of predicted spans (each an iterable such as ``[start, end]``).
        labl: Iterable of gold spans in the same format.

    Returns:
        ``(true_positives, false_positives, false_negatives)``.
    """
    predicted_spans = {tuple (span) for span in pred}
    gold_spans = {tuple (span) for span in labl}

    return (
        len (predicted_spans & gold_spans),
        len (predicted_spans - gold_spans),
        len (gold_spans - predicted_spans),
    )

def metricator (preds, labels):
    """Micro-averaged precision, recall and F1 over all texts.

    The per-text counts from ``comparator`` are summed before the ratios are
    formed. ``preds[i]`` is paired with ``labels[i]`` for every index of
    ``labels``. Each ratio falls back to ``0`` when its denominator is zero.

    Returns:
        ``(precision, recall, f1_score)``.
    """
    tp_total = fp_total = fn_total = 0

    for i in range (len (labels)):
        tp, fp, fn = comparator (preds [i], labels [i])
        tp_total += tp
        fp_total += fp
        fn_total += fn

    predicted_count = tp_total + fp_total
    gold_count = tp_total + fn_total

    precision = tp_total / predicted_count if predicted_count > 0 else 0
    recall = tp_total / gold_count if gold_count > 0 else 0

    pr_sum = precision + recall
    f1_score = 2 * (precision * recall) / pr_sum if pr_sum > 0 else 0

    return precision, recall, f1_score

In [None]:
# Gold annotations reduced to [start, end] spans, and the validation texts as a plain list.
val_labels_lst = label_constructor (val_labels_txt)
val_data_lst = val_data_txt.tolist ()

In [None]:
# Pass the current model/tokenizer explicitly. The function's default arguments
# were bound when its cell ran, so they go stale if the model cell is re-run.
predictions = predict_with_model_effective (val_data_lst, model = model, tokenizer = tokenizer)

In [None]:
# Micro-averaged (precision, recall, F1) on the validation split.
res = metricator (preds = predictions, labels = val_labels_lst)
res

In [None]:
# Test split. Note: this rebinds `df`, replacing the validation frame.
df = pd.read_json ('./test2_t12_v2.jsonl', lines = True)
print (df.head ())

test_data_txt = df ['text']
test_data_lst = test_data_txt.tolist ()

In [None]:
# Pass the current model/tokenizer explicitly: the function's default arguments
# were bound when its cell ran and go stale if the model cell is re-run.
predictions = predict_with_model (test_data_lst, model = model, tokenizer = tokenizer)
# Preview only the first few results instead of dumping every prediction.
predictions [:5]

In [None]:
# Show one test document followed by the text of every span predicted for it.
idx = 2
sample_text = test_data_lst [idx]
print (sample_text)
for span in predictions [idx]:
    print (sample_text [span [0]: span [1]], end = ', ')

In [None]:
# Store the predicted spans as the `label` column of the test frame.
df = df.assign (label = predictions)

In [None]:
# One JSON object per row; force_ascii=False writes non-ASCII characters unescaped.
OUTPUT_PATH = 'res-digr-test2_t12_v2.jsonl'
df.to_json (OUTPUT_PATH, orient = 'records', lines = True, force_ascii = False)