In [1]:
# Put ./src on the module search path (at position 1, right after the first entry)
# before any project import runs.
# NOTE(review): presumably the anonymization / mask / models / utils / datasets packages
# imported in the next cell live under ./src — confirm. The path is relative, so the
# notebook has to be started from the repository root.
import sys
sys.path.insert(1, "./src")

In [16]:
import pickle
from pathlib import Path

import yaml
from transformers import set_seed

from anonymization.gpt2_generation import GPT2GenerationAnonymization
from mask.util import convert_masked_docs_to_segments_set
from models.gpt2_model import PretrainedGPT2TextInfilling
from utils.infill_metrics import Statistics
from utils.log_reader import TensorBoardReader
from datasets.ner_dataset import get_ner_dataset

In [3]:
# Fix the random seed through transformers' helper so that sampling-based generation
# is repeatable between runs of the notebook.
set_seed(42)

In [4]:
# Load the per-dataset configs (ROC stories, i2b2-2006, i2b2-2014).
# Every file is opened in a `with` block so its handle is closed as soon as the YAML
# has been parsed instead of lingering until garbage collection.
# NOTE(review): yaml.Loader can instantiate arbitrary Python objects, which is only
# acceptable for trusted, repository-local configs. If these files hold plain data,
# yaml.safe_load would be the safer choice.
with open("configs/roc_stories_data_config.yaml", 'r') as config_file:
    roc_stories_data_config = yaml.load(config_file, Loader=yaml.Loader)

with open("configs/i2b2-2006_data_config.yaml", 'r') as config_file:
    i2b2_2006_data_config = yaml.load(config_file, Loader=yaml.Loader)

with open("configs/i2b2-2014_data_config.yaml", 'r') as config_file:
    i2b2_2014_data_config = yaml.load(config_file, Loader=yaml.Loader)

In [5]:
# Anonymization config initialisation.
# The file is opened in a `with` block so the handle is closed right after parsing.
# NOTE(review): yaml.Loader can instantiate arbitrary Python objects, which is only
# acceptable for trusted, repository-local configs. If the file holds plain data,
# yaml.safe_load would be the safer choice.
with open("configs/gpt2_anonymization_config.yaml", 'r') as config_file:
    anon_gpt2_config = yaml.load(config_file, Loader=yaml.Loader)

In [6]:
# Set up the algorithm that generates artificial entities with GPT2.

# Find the checkpoint of the configured model version under <log_dir>/lightning_logs.
lightning_logs_dir = Path(anon_gpt2_config["log_dir"]) / "lightning_logs"
model_reader = TensorBoardReader(lightning_logs_dir)
path_to_checkpoint = model_reader.get_ckpt_path(anon_gpt2_config["model_version"])

# Restore the text-infilling model from that checkpoint and place it on the first CUDA device.
text_infill_model = PretrainedGPT2TextInfilling.load_from_checkpoint(path_to_checkpoint)
text_infill_model = text_infill_model.to("cuda:0")

# The whole anonymization config is forwarded to the anonymizer as keyword arguments.
anonymization = GPT2GenerationAnonymization(text_infill_model, **anon_gpt2_config)

#### Проверка заполнения пропусков на случайным образом замаскированных текстах (roc stories)

In [7]:
# Resolve the pickled ROC-stories file for the chosen split.
split = "train"
path_to_data = roc_stories_data_config["train_data_path"]
data_path = Path(path_to_data)
if data_path.suffix != '.pkl':
    # The config does not point at a pickle: use the sibling file "<stem>_<split>.pkl".
    path_to_data = str(data_path.parent / f'{data_path.stem}_{split}.pkl')

# infill_dataset layout:
# [(document text, list of mask sets for it: [[(type, offset, length), ...], ...]), ...]
# NOTE: pickle.load can execute arbitrary code from the file; only load trusted data.
with open(path_to_data, 'rb') as f:
    infill_dataset = pickle.load(f)

# categories_list: entity categories as [list of segment categories in a document, ...]
# source_texts:    source text as [list of segments in a document, ...]
categories_list, source_texts = convert_masked_docs_to_segments_set(infill_dataset)

In [8]:
# Build the infilling statistics on the first 1000 documents only.
# The same category list is passed for both category arguments of Statistics.
first_docs = slice(1000)
infill_stats = Statistics(
    anonymization,
    categories_list[first_docs],
    categories_list[first_docs],
    source_texts[first_docs],
    is_uncased=True,
)

Start data tokenization


100%|██████████| 1000/1000 [00:00<00:00, 4783.30it/s]
100%|██████████| 42/42 [00:46<00:00,  1.11s/it]


In [9]:
# Request 30 example indexes from the statistics object and print the last five of them.
# NOTE(review): the ranking and its direction are defined by
# Statistics.most_close_examples_indexes — confirm what "last five" selects.
indexes = infill_stats.most_close_examples_indexes(30)
infill_stats.print_examples_by_indexes(indexes[-5:])

_____ Record 577 _____
| Labels:           | NGRAM                                  | O                                                              | NGRAM  | O               | NGRAM | O                                                                                     | NGRAM |
| Source text:      | Luis had a present from a year earlier | . It was from his mom. He refused to open it because he wanted | her to | see him open it | .     | A few days later she stopped by. They were reunited and he finally opened the present | .     |
| Substituted text: | jim wanted to                          | . It was from his mom. He refused to open it because he wanted | to     | see him open it | .     | A few days later she stopped by. They were reunited and he finally opened the present | .     |
| CER               | 0.8157894611358643                     |                                                                | 0.5    |                 | 0.0   |                                      

#### Проверка заполнения пропусков в примерах с личной информацией (i2b2_2006)

In [17]:
# Called for its caching side effect: building the dataset caches the data.
# NOTE(review): presumably this writes the .pkl file that the next cell reads — confirm.
# The returned dataset object is not stored; it is only echoed as the cell output.
get_ner_dataset(path_to_folder=i2b2_2006_data_config["train_data_path"], **i2b2_2006_data_config)

Token indices sequence length is longer than the specified maximum sequence length for this model (526 > 512). Running this sequence through the model will result in indexing errors


<datasets.ner_dataset.I2b2SixNerDataset at 0x7f21572be0b0>

In [18]:
# Read the cached i2b2-2006 train data: same location as the configured train path,
# but with the suffix replaced by ".pkl".
# NOTE: pickle.load can execute arbitrary code from the file; only load trusted data.
cache_path = Path(i2b2_2006_data_config["train_data_path"]).with_suffix(".pkl")
path = str(cache_path)
with open(path, 'rb') as f:
    # The pickle holds a 5-item record; the first and last items are not needed here.
    _, source_texts, specific_category_list, general_category_list, _ = pickle.load(f)

In [19]:
# Build the infilling statistics for the whole i2b2-2006 set; positional order is
# (general categories, specific categories, source texts).
# This rebinds infill_stats, replacing the ROC-stories statistics object created earlier.
infill_stats = Statistics(anonymization, general_category_list, specific_category_list, source_texts, is_uncased=True)

Start data tokenization


100%|██████████| 669/669 [00:02<00:00, 238.94it/s]
100%|██████████| 174/174 [03:49<00:00,  1.32s/it]


In [20]:
# Request 30 example indexes from the statistics object and print the last five of them.
# NOTE(review): the ranking and its direction are defined by
# Statistics.most_close_examples_indexes — confirm what "last five" selects.
indexes = infill_stats.most_close_examples_indexes(30)
infill_stats.print_examples_by_indexes(indexes[-5:])

_____ Record 274 _____
| Labels:           | ID        | HOSPITAL | ID      | ID    | DATE  | O                         | PATIENT              | O     | ID                 | O                   | DATE  | O                                         | PATIENT              | O                                      | HOSPITAL                              | O  | DATE               | O                                                                                                                       | HOSPITAL                               | O                                           | DOCTOR             | O                                              | PHONE        | O              |
| Source text:      | 139519613 | tgcho    | 6214129 | 12892 | 04/11 | /1999 12:00:00 am         | kote , lyfranklapalm | mrn : | 6214129            | age :               | 04/11 | /1999 03:52 pm                            | kote , lyfranklapalm | arrived in the emergency department at | tecal galecounxopt com