In [None]:
import pandas as pd
import regex as re

from sklearn.model_selection import train_test_split

import ahocorasick

import torch
from transformers import AutoTokenizer, MT5ForConditionalGeneration, get_scheduler
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW

from tqdm import tqdm
import matplotlib.pyplot as plt

from copy import deepcopy

In [None]:
# Separator inserted between extracted terms in every target string; the
# leading '▁' appears to be the SentencePiece word-boundary marker.
# INPUT_PREFIX is an optional task prefix for every model input (none here).
SEP_TOKEN, INPUT_PREFIX = '▁<extra_id_0>', ''

In [None]:
# Training corpus: keep only the text and its span annotations.
df = (
    pd.read_json ('train_t1_v1.jsonl', lines = True)
    .drop (columns = ['id', 'keywords'])
)
print (df.head ())

In [None]:
def split_text (text, segments, delimiters = ('...', '.', '?!', '?', '!')):
    """Split ``text`` into paragraph/sentence chunks and attach annotated terms.

    The text is first cut into paragraph chunks (a chunk ends after a newline
    plus the run of spaces following it). Each chunk is then cut into
    sentences at positions preceded by ``<delimiter><space>`` and followed by
    an uppercase Latin or Cyrillic letter. Concatenating all returned
    sentences reproduces ``text`` exactly, so character offsets stay aligned
    with ``segments``.

    Args:
        text: Source document.
        segments: Iterable of ``(start, end)`` character offsets into ``text``.
        delimiters: Sentence-ending punctuation strings. If empty, paragraphs
            are not split further.

    Returns:
        List of ``(sentence, terms)`` tuples. ``terms`` holds
        ``text[start:end]`` for every segment lying entirely inside the
        sentence; segments that straddle a sentence boundary are dropped.
    """
    # One fixed-width lookbehind per delimiter, OR-ed together. The previous
    # version placed the '|'-joined alternation inside a character class,
    # which made '|' itself a sentence delimiter and reduced multi-character
    # delimiters to single characters. This form is also valid for the
    # stdlib ``re`` (no variable-width lookbehind needed).
    if delimiters:
        boundary_pattern = (
            '(?:' + '|'.join (f'(?<={re.escape (d)} )' for d in delimiters) + ')'
            + '(?=[A-ZА-ЯЁ])'
        )
    else:
        boundary_pattern = None

    # Paragraph chunks: text up to and including a newline and the spaces
    # after it, or a trailing newline-free remainder. A zero-width split on
    # a variable-width newline-plus-spaces lookbehind produced one extra
    # chunk per indentation space (and an empty chunk after a final newline).
    paragraphs = re.findall (r'[^\n]*\n[ ]*|[^\n]+', text)

    sentences_with_segments = []
    current_start_index = 0

    for paragraph in paragraphs:
        sentences = re.split (boundary_pattern, paragraph) if boundary_pattern else [paragraph]

        for sentence in sentences:
            start_index = current_start_index
            end_index = start_index + len (sentence)

            matched_segments = [
                text [start: end] for start, end in segments if start >= start_index and end <= end_index
            ]
            sentences_with_segments.append ((sentence, matched_segments))

            current_start_index = end_index

    return sentences_with_segments

In [None]:
# The whole training file is used for training; validation and test halves
# come from a separate held-out file below. (An earlier variant split this
# file 80/10/10 with train_test_split(..., random_state = 14).)
train_data_txt, train_labels_txt = df ['text'], df ['label']

In [None]:
def label_constructor (labels):
    """Drop the class field from every ``[start, end, cls]`` span.

    Args:
        labels: Iterable (e.g. a pandas Series) of per-document span lists,
            where each span is a 3-item ``[start, end, cls]`` sequence.

    Returns:
        List of per-document lists of ``[start, end]`` pairs.

    Raises:
        ValueError: If a span does not have exactly three items.
    """
    return [
        [[span_start, span_end] for span_start, span_end, _cls in doc_spans]
        for doc_spans in labels
    ]

In [None]:
# Held-out annotated file; only the text and span labels are needed.
df_2 = pd.read_json ('./test_data/test1_t12_full_v2.jsonl', lines = True).loc [:, ['text', 'label']]
print (df_2.head ())

In [None]:
# Split the held-out file 50/50 into validation and test halves with a fixed
# seed. (Alternative used earlier: evaluate on the whole file for both.)
split_parts = train_test_split (df_2 ['text'], df_2 ['label'], test_size = 0.5, random_state = 14)
val_data_txt, test_data_txt, val_labels_txt, test_labels_txt = split_parts

In [None]:
def build_parallel_corpus (texts, label_spans):
    """Turn documents plus span annotations into sentence-level seq2seq pairs.

    Args:
        texts: List of document strings.
        label_spans: Parallel list; each item is a list of ``(start, end)``
            character offsets into the matching document.

    Returns:
        ``(inputs, targets)``. ``inputs`` holds ``INPUT_PREFIX + sentence``
        for every sentence produced by ``split_text``. ``targets`` holds the
        stripped terms of that sentence joined with ``SEP_TOKEN``, or an empty
        string when the sentence has no terms.
    """
    inputs, targets = [], []
    for text, segments in zip (texts, label_spans):
        for sentence, terms in split_text (text, segments):
            inputs.append (INPUT_PREFIX + sentence)
            # Same result as appending ``term + SEP_TOKEN`` for each term and
            # trimming the final separator.
            targets.append (SEP_TOKEN.join (term.strip () for term in terms))
    return inputs, targets


# split_text unpacks (start, end) pairs, so training labels are used as-is,
# while validation/test labels ((start, end, cls) triples) go through
# label_constructor first.
train_data_lst = train_data_txt.tolist ()
train_labels_lst = train_labels_txt.tolist ()
parallel_text, parallel_label = build_parallel_corpus (train_data_lst, train_labels_lst)

val_data_lst = val_data_txt.tolist ()
val_labels_lst = label_constructor (val_labels_txt)
parallel_text_val, parallel_label_val = build_parallel_corpus (val_data_lst, val_labels_lst)

test_data_lst = test_data_txt.tolist ()
test_labels_lst = label_constructor (test_labels_txt)
parallel_text_test, parallel_label_test = build_parallel_corpus (test_data_lst, test_labels_lst)

In [None]:
# Model / tokenization / batching configuration.
USED_MODEL_NAME = 'cointegrated/rut5-small'

SEQ_MAX_LENGTH = 150                  # token length for both inputs and targets
BATCH_SIZE, EVAL_BATCH_SIZE = 4, 16   # train / evaluation batch sizes

# When True, pad ids in the tokenized targets are replaced with -100.
ENABLE_LABEL_FIX = True

In [None]:
# Load the tokenizer matching the pretrained checkpoint.
tokenizer = AutoTokenizer.from_pretrained (pretrained_model_name_or_path = USED_MODEL_NAME)

In [None]:
def _encode (texts):
    """Tokenize a list of strings into fixed-length (SEQ_MAX_LENGTH) PyTorch tensors."""
    return tokenizer (texts, padding = 'max_length', truncation = True, max_length = SEQ_MAX_LENGTH, return_tensors = 'pt')

train_data, train_labels = _encode (parallel_text), _encode (parallel_label)
val_data, val_labels = _encode (parallel_text_val), _encode (parallel_label_val)
test_data, test_labels = _encode (parallel_text_test), _encode (parallel_label_test)

In [None]:
def replace_padding (labels, pad_token_id = 0):
    """Mask padding positions in tokenized labels so the loss ignores them.

    Every occurrence of ``pad_token_id`` in ``labels['input_ids']`` is
    replaced in place with ``-100``, the default ``ignore_index`` of
    ``torch.nn.CrossEntropyLoss``.

    Args:
        labels: Mapping (e.g. a tokenizer ``BatchEncoding``) whose
            ``'input_ids'`` entry is an integer tensor.
        pad_token_id: Id to mask. Defaults to ``0``, the value previously
            hard-coded; pass ``tokenizer.pad_token_id`` for other tokenizers.

    Returns:
        The same ``labels`` object, mutated in place.
    """
    input_ids = labels ['input_ids']
    # Vectorized in-place masking; replaces a Python loop that rebuilt each
    # row element by element.
    input_ids.masked_fill_ (input_ids == pad_token_id, - 100)
    return labels

# Mask target padding in every split so it does not contribute to the loss.
if ENABLE_LABEL_FIX:
    train_labels = replace_padding (train_labels)
    val_labels = replace_padding (val_labels)
    test_labels = replace_padding (test_labels)

train_labels ['input_ids']

In [None]:
class Seq2SeqDataset (Dataset):
    """Map-style dataset pairing tokenized inputs with tokenized targets.

    Args:
        encodings: Mapping with ``'input_ids'`` and ``'attention_mask'``
            (e.g. a tokenizer ``BatchEncoding``), indexable per example.
        labels: Mapping whose ``'input_ids'`` entry holds the target ids.
    """

    def __init__ (self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__ (self):
        # One example per row of input ids.
        return len (self.encodings ['input_ids'])

    def __getitem__ (self, idx):
        enc = self.encodings
        item = {key: enc [key] [idx] for key in ('input_ids', 'attention_mask')}
        item ['labels'] = self.labels ['input_ids'] [idx]
        return item

# Wrap each split; only the training loader shuffles.
train_dataset, val_dataset, test_dataset = (
    Seq2SeqDataset (train_data, train_labels),
    Seq2SeqDataset (val_data, val_labels),
    Seq2SeqDataset (test_data, test_labels),
)

train_loader = DataLoader (train_dataset, batch_size = BATCH_SIZE, shuffle = True)
val_loader, test_loader = (DataLoader (ds, batch_size = EVAL_BATCH_SIZE) for ds in (val_dataset, test_dataset))

In [None]:
# Prefer the GPU when one is available.
device = torch.device ('cuda') if torch.cuda.is_available () else torch.device ('cpu')

# Module.to returns the module itself, so loading and moving can be chained.
model = MT5ForConditionalGeneration.from_pretrained (USED_MODEL_NAME).to (device)
print (model.device)


In [None]:
def get_set (tensor, ground_truth = True, tokenizer = tokenizer):
    """Decode a token-id sequence into the set of extracted terms.

    Args:
        tensor: 1-D tensor of token ids (a reference label row or a
            generated sequence).
        ground_truth: If True, ``tensor`` is a reference label row and is cut
            at its first EOS token. If False, ``tensor`` is a model output and
            every pad, ``-100`` and EOS id is removed.
        tokenizer: Tokenizer used for decoding; it also supplies the pad and
            EOS ids (falling back to ``0`` / ``1``, the values previously
            hard-coded, if the tokenizer does not define them).

    Returns:
        Set of stripped term strings. The empty string is never included, so
        a sentence without terms yields an empty set.
    """
    # Targets were built with SEP_TOKEN; for the default '▁<extra_id_0>' the
    # decoded text is split on the bare sentinel (presumably because decode()
    # does not render the leading '▁' verbatim).
    separator = '<extra_id_0>' if SEP_TOKEN == '▁<extra_id_0>' else SEP_TOKEN

    pad_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id
    eos_id = 1 if tokenizer.eos_token_id is None else tokenizer.eos_token_id

    if ground_truth:
        # Everything from the first EOS onward is padding (pad id or -100).
        eos_positions = (tensor == eos_id).nonzero ()
        cut = int (eos_positions [0] [0]) if eos_positions.numel () > 0 else len (tensor)
        seq = tensor [:cut]
    else:
        # Drop pad ids, -100 and EOS ids wherever they occur.
        seq = tensor [(tensor != pad_id) & (tensor != - 100) & (tensor != eos_id)]

    txt = tokenizer.decode (seq)
    res = {item.strip () for item in txt.split (separator)}

    # An empty chunk is never a term. The previous `len(res) > 1` guard kept
    # '' for term-less sentences, so an empty prediction on an empty reference
    # counted as a true positive (and a term-less reference produced a false
    # negative), inflating/distorting precision, recall and F1.
    res.discard ('')

    return res


def sanity_check (preds, labels, to_print = False):
    """Micro-averaged, set-based precision / recall / F1 over term sets.

    For each index, ``preds[idx]`` (generated ids) and ``labels[idx]``
    (reference ids) are decoded into term sets with ``get_set``. True
    positives, false positives and false negatives are summed over all pairs
    before the scores are computed.

    Args:
        preds: Sequence of generated id tensors; must be at least as long as
            ``labels``.
        labels: Sequence of reference id tensors.
        to_print: If True, print both sets for every pair.

    Returns:
        ``(precision, recall, f1_score)``; each is 0 when its denominator is 0.
    """
    true_pos = false_pos = false_neg = 0

    for idx, label_ids in enumerate (labels):
        pred_terms = get_set (preds [idx], ground_truth = False)
        gold_terms = get_set (label_ids)

        if to_print:
            print (f'True: {gold_terms}\nPred: {pred_terms}')

        true_pos += len (gold_terms & pred_terms)    # correctly predicted terms
        false_pos += len (pred_terms - gold_terms)   # spurious predictions
        false_neg += len (gold_terms - pred_terms)   # missed terms

    predicted_total = true_pos + false_pos
    actual_total = true_pos + false_neg
    precision = true_pos / predicted_total if predicted_total > 0 else 0
    recall = true_pos / actual_total if actual_total > 0 else 0
    pr_sum = precision + recall
    f1_score = 2 * (precision * recall) / pr_sum if pr_sum > 0 else 0

    return precision, recall, f1_score

In [None]:
# Per-epoch history: each metric maps to a list of (epoch, value) tuples.
metrics = {key: [] for key in ('train_loss', 'val_loss', 'precision', 'recall', 'f1_score')}

In [None]:
num_epochs = 50
# The training loop calls scheduler.step() once per batch, i.e.
# len(train_loader) times per epoch. Floor division
# (len(train_data['input_ids']) // BATCH_SIZE) drops the final partial batch
# from the count, so the linear schedule reached lr = 0 before training ended.
training_steps = len (train_loader) * num_epochs
warmup_steps = int (training_steps * 0.1)
print (f'Suggested train steps: {training_steps}\n\t warmup steps: {int (training_steps * 0.05)} - {int (training_steps * 0.1)}')

In [None]:
# AdamW over the trainable parameters only, with a linear warmup/decay
# schedule. (An Adafactor variant with lr = 3e-5, weight_decay = 0.02 was
# tried earlier.)
trainable_params = [param for param in model.parameters () if param.requires_grad]
optimizer = AdamW (trainable_params, lr = 1e-5, weight_decay = 0.01)

scheduler = get_scheduler (
    'linear',
    optimizer = optimizer,
    num_warmup_steps = warmup_steps,
    num_training_steps = training_steps,
)

In [None]:
# Best validation F1 seen so far. Starting at -1 guarantees the first epoch is
# snapshotted, so prev_model / prev_optimizer always exist after training.
prev_metric = - 1.0

for epoch in range (num_epochs):
    model.train ()
    total_loss = 0

    for batch in tqdm (train_loader):
        input_ids = batch ['input_ids'].to (model.device)
        attention_mask = batch ['attention_mask'].to (model.device)
        labels = batch ['labels'].to (model.device)

        outputs = model (input_ids = input_ids, attention_mask = attention_mask, labels = labels)
        loss = outputs.loss
        total_loss += loss.item ()

        loss.backward ()
        optimizer.step ()
        optimizer.zero_grad ()

        scheduler.step ()

        torch.cuda.empty_cache ()

    avg_loss = total_loss / len (train_loader)
    metrics ['train_loss'].append ((epoch, avg_loss))

    model.eval ()
    # Local names: reusing `val_labels` here would overwrite the tokenized
    # validation labels created in an earlier cell.
    epoch_preds, epoch_labels = [], []
    total_val_loss = 0

    with torch.no_grad ():
        for batch in val_loader:
            input_ids = batch ['input_ids'].to (model.device)
            attention_mask = batch ['attention_mask'].to (model.device)
            labels = batch ['labels'].to (model.device)

            outputs = model.generate (input_ids = input_ids, attention_mask = attention_mask, max_length = SEQ_MAX_LENGTH)

            total_val_loss += model (input_ids = input_ids, attention_mask = attention_mask, labels = labels).loss.item ()

            epoch_preds.extend (outputs)
            epoch_labels.extend (labels)

            torch.cuda.empty_cache ()

    avg_val_loss = total_val_loss / len (val_loader)

    prec, recl, f1sc = sanity_check (epoch_preds, epoch_labels)

    metrics ['val_loss'].append ((epoch, avg_val_loss))
    metrics ['precision'].append ((epoch, prec))
    metrics ['recall'].append ((epoch, recl))
    metrics ['f1_score'].append ((epoch, f1sc))

    # Snapshot the best-F1 epoch. The previous condition compared this F1
    # tracker against the validation loss (`prev_metric > avg_val_loss`); with
    # prev_metric starting at 0 and a non-negative loss it was never true, so
    # no snapshot was ever stored.
    if f1sc > prev_metric:
        prev_model = deepcopy (model.state_dict ())
        prev_optimizer = deepcopy (optimizer.state_dict ())
        prev_metric = f1sc

    print (f'Epoch {epoch + 1} / {num_epochs}, Loss: {avg_loss:.4f}, Validation loss: {avg_val_loss:.4f}, {prec} / {recl} / {f1sc}')

In [None]:
# Dump the raw per-epoch history collected during training.
print (f'{metrics}')

In [None]:
# One figure per tracked metric, in alphabetical order; each history is a
# list of (epoch, value) pairs.
for metric_name, history in sorted (metrics.items ()):
    fig, ax = plt.subplots (figsize = (10, 4))
    ax.set_title (metric_name)
    ax.plot (*zip (*history))
    ax.grid ()
    plt.show ()

In [None]:


# Evaluate the current (final-epoch) weights on the validation and test splits.
model.eval ()

for split_name, loader in (('Validation', val_loader), ('Test', test_loader)):
    # Local names so the tokenized `val_labels` encoding from an earlier cell
    # is not overwritten by a plain list.
    split_preds, split_labels = [], []

    with torch.no_grad ():
        for batch in tqdm (loader):
            input_ids = batch ['input_ids'].to (model.device)
            attention_mask = batch ['attention_mask'].to (model.device)
            labels = batch ['labels'].to (model.device)
            # NOTE(review): 170 differs from SEQ_MAX_LENGTH used for generation
            # in the training loop; kept to preserve reported results.
            out = model.generate (input_ids = input_ids, attention_mask = attention_mask, max_length = 170)

            split_preds.extend (out)
            split_labels.extend (labels)

            torch.cuda.empty_cache ()

    print (f'{split_name}: ', sanity_check (split_preds, split_labels))

In [None]:
# Restore the snapshot (prev_model / prev_optimizer) taken by the training
# loop; fail with an explicit message if no snapshot was ever stored.
if 'prev_model' not in globals () or 'prev_optimizer' not in globals ():
    raise RuntimeError ('No checkpoint snapshot found: the training loop never stored prev_model / prev_optimizer.')
model.load_state_dict (prev_model)
optimizer.load_state_dict (prev_optimizer)

In [None]:

# Re-evaluate on the validation and test splits with the weights restored
# from the training-loop snapshot. Set eval mode explicitly rather than
# relying on an earlier cell having done so.
model.eval ()

for split_name, loader in (('Validation wmax', val_loader), ('Test wmax', test_loader)):
    # Local names so the tokenized `val_labels` encoding from an earlier cell
    # is not overwritten by a plain list.
    split_preds, split_labels = [], []

    with torch.no_grad ():
        for batch in tqdm (loader):
            input_ids = batch ['input_ids'].to (model.device)
            attention_mask = batch ['attention_mask'].to (model.device)
            labels = batch ['labels'].to (model.device)
            # NOTE(review): 170 differs from SEQ_MAX_LENGTH used for generation
            # in the training loop; kept to preserve reported results.
            out = model.generate (input_ids = input_ids, attention_mask = attention_mask, max_length = 170)

            split_preds.extend (out)
            split_labels.extend (labels)

            torch.cuda.empty_cache ()

    print (f'{split_name}: ', sanity_check (split_preds, split_labels))

In [None]:
#model.save_pretrained('./coint_rut5small_finetune_fulltrain_novalid')
#tokenizer.save_pretrained('./coint_rut5small_finetune_fulltrain_novalid')