In [1]:
import os
import sys
import random
import itertools
import pickle
import json

import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm
from pandarallel import pandarallel

import pymorphy2
import nltk
from nltk.tokenize import word_tokenize, wordpunct_tokenize, sent_tokenize

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

from keras.preprocessing.sequence import pad_sequences

from transformers import T5Model, T5Tokenizer, T5ForConditionalGeneration

sys.path.append("..")
from src import *

def init_random_seed(value=0):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU and all CUDA devices), and switches cuDNN to its
    deterministic configuration.

    Parameters
    ----------
    value : int, default 0
        Seed applied to all generators.
    """
    random.seed(value)
    np.random.seed(value)
    torch.manual_seed(value)
    # Seed all CUDA devices rather than only the current one, so multi-GPU
    # runs are covered as well.
    torch.cuda.manual_seed_all(value)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may select different kernels from run to run;
    # the PyTorch reproducibility notes recommend disabling it.
    torch.backends.cudnn.benchmark = False


init_random_seed()
    
pd.set_option('display.max_colwidth', 255)
tqdm.pandas()
pandarallel.initialize(progress_bar=True, nb_workers=14, use_memory_fs=False)

%load_ext autoreload
%autoreload 2

INFO: Pandarallel will run on 14 workers.
INFO: Pandarallel will use standard multiprocessing data transfer (pipe) to transfer data between the main process and workers.


In [2]:
# Read the abbreviation table and the train / test corpus splits in one pass.
abbr, lenta_train, lenta_test = (
    pd.read_csv(csv_path)
    for csv_path in (
        "../data/abbr.csv",
        "../data/lenta_train_t5.csv",
        "../data/lenta_test_t5.csv",
    )
)

In [3]:
# Each row stores its tokens (and the per-token labels) as one
# space-separated string; split them into lists.
train_texts = [row.split(" ") for row in lenta_train["text_upd"].to_list()]
train_labels = [row.split(" ") for row in lenta_train["labels_upd"].to_list()]

test_texts = [row.split(" ") for row in lenta_test["text_upd"].to_list()]
test_labels = [row.split(" ") for row in lenta_test["labels_upd"].to_list()]

In [4]:
def get_sentences(text, labels):
    """Split a tokenised text into sentences and slice its labels to match.

    Parameters
    ----------
    text : list of str
        Tokens of one document; they are joined with single spaces before
        sentence splitting.
    labels : list of str
        Labels consumed in order, as many per sentence as the sentence has
        space-separated tokens.

    Returns
    -------
    tuple
        ``(sentences, sent_labels)``: the sentence strings produced by
        ``sent_tokenize`` and, for each of them, its labels joined with
        single spaces.
    """
    joined_text = " ".join(text)
    sentences = sent_tokenize(joined_text, language="russian")

    # NOTE(review): the slicing below assumes the sentence strings together
    # contain exactly the original space-separated tokens -- confirm for
    # inputs with empty tokens at sentence boundaries.
    sent_labels = []
    start = 0
    for sentence in sentences:
        stop = start + len(sentence.split(" "))
        sent_labels.append(" ".join(labels[start:stop]))
        start = stop
    return sentences, sent_labels

In [5]:
# Flatten the training documents into parallel lists of sentences and
# per-sentence label strings.
sentences = []
sentences_labels = []
# zip() has no length, so pass total= to give tqdm a real progress bar / ETA.
for text, labels in tqdm(zip(train_texts, train_labels), total=len(train_texts)):
    sent, sent_labels = get_sentences(text, labels)
    sentences.extend(sent)
    sentences_labels.extend(sent_labels)

403411it [00:43, 9229.81it/s]


In [16]:
def get_fill_tasks(sent, labels):
    """Build one fill-in task for every labelled token of a sentence.

    Parameters
    ----------
    sent : str
        Sentence whose tokens are separated by single spaces.
    labels : str
        Space-separated labels, one per token of ``sent``. ``"_"`` marks a
        token that has no expansion.

    Returns
    -------
    list of tuple
        ``(task, desc)`` pairs. ``task`` is
        ``"fill <token> | <sentence with that token replaced by <extra_id_1>>"``
        and ``desc`` is the label of the replaced token.

    Raises
    ------
    IndexError
        If a non-empty label other than ``"_"`` sits at a position beyond
        the last token of ``sent``.
    """
    # Split once; every task works on its own copy of the token list.
    tokens = sent.split(" ")
    tasks = []
    for i, desc in enumerate(labels.split(" ")):
        # "_" means "nothing to expand". An empty label appears when the
        # label string is empty or contains doubled spaces; it is not an
        # expansion, so it must not produce a task with an empty target.
        if desc in ("_", ""):
            continue
        masked = tokens.copy()
        masked[i] = "<extra_id_1>"
        tasks.append((f"fill {tokens[i]} | {' '.join(masked)}", desc))
    return tasks

In [17]:
# Collect the fill tasks of every sentence gathered above into one flat list.
fill_tasks = []
# zip() has no length, so pass total= to give tqdm a real progress bar / ETA.
for sent_i, labels_i in tqdm(zip(sentences, sentences_labels), total=len(sentences)):
    fill_tasks.extend(get_fill_tasks(sent_i, labels_i))

4083077it [00:03, 1220538.24it/s]


In [18]:
# Display the task at index 2320 as a spot check of the generated pairs.
fill_tasks[2320]

("fill сд | Также , по словам Фадеева , до конца недели правительство утвердит состав <extra_id_1> ОАО `` РЖД '' , после чего будет проведено первое заседание Совета директоров .",
 'Совета=директоров')

In [19]:
# Persist the current `fill_tasks` list as a pickle. At this point it holds
# the tasks built from the training sentences above; the name is rebound for
# the test split further down, so the cells must be run in order.
with open("../data/fill_tasks_t5.pickle", "wb") as f:
    pickle.dump(fill_tasks, f)

## Тестовая выборка 

In [20]:
# Flatten the test documents into parallel lists of sentences and
# per-sentence label strings (this rebinds the lists used for the train split).
sentences = []
sentences_labels = []
# zip() has no length, so pass total= to give tqdm a real progress bar / ETA.
for text, labels in tqdm(zip(test_texts, test_labels), total=len(test_texts)):
    sent, sent_labels = get_sentences(text, labels)
    sentences.extend(sent)
    sentences_labels.extend(sent_labels)

100853it [00:10, 9230.60it/s]


In [21]:
# Collect the fill tasks of every sentence gathered above into one flat list
# (this rebinds `fill_tasks`, which previously held the train-split tasks).
fill_tasks = []
# zip() has no length, so pass total= to give tqdm a real progress bar / ETA.
for sent_i, labels_i in tqdm(zip(sentences, sentences_labels), total=len(sentences)):
    fill_tasks.extend(get_fill_tasks(sent_i, labels_i))

1021377it [00:00, 1166569.94it/s]


In [22]:
# Display the task at index 2320 as a spot check of the generated pairs.
fill_tasks[2320]

('fill инфа | Ранее 5 марта в СМИ появилась <extra_id_1> о том , что Медведева выразила желание выступать за другую страну .',
 'информация')

In [23]:
# Persist the current `fill_tasks` list as a pickle. At this point it holds
# the tasks built from the test sentences in the cells directly above.
with open("../data/fill_tasks_t5_test.pickle", "wb") as f:
    pickle.dump(fill_tasks, f)