In [None]:
from tqdm import tqdm
import random
import re
import os
from google.colab import drive
import json
import pandas as pd

# Mount Google Drive so the project folder is reachable from the Colab runtime.
drive.mount('/content/drive')

# Project directory on the mounted drive; a later cell builds the synonyms CSV path from it.
folder_path = '/content/drive/My Drive/HSE/programming/DL in NLP/Project'

# List the folder contents as a quick check that the mount worked.
for file in os.listdir(folder_path):
    print(file)

Mounted at /content/drive
synonims.csv
RuDReC NER.gslides
clinical_corpus
BIO_data.csv
annotation
sentence_data_splitted
label2id.json
id2label.json
model
augmentations
ner_sent_results
models_ner_sent
NER+sent_classification.ipynb
models_ner
NER.ipynb
errors_analysis.ipynb
RuDrec_project.gdoc


In [None]:
!wget "https://raw.githubusercontent.com/cimm-kzn/RuDReC/master/data/rudrec_annotated.json"

--2024-03-24 17:55:00--  https://raw.githubusercontent.com/cimm-kzn/RuDReC/master/data/rudrec_annotated.json
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1773014 (1.7M) [text/plain]
Saving to: ‘rudrec_annotated.json’


2024-03-24 17:55:01 (9.38 MB/s) - ‘rudrec_annotated.json’ saved [1773014/1773014]



In [None]:
datapath = './rudrec_annotated.json'

# The corpus is stored as JSON Lines: one JSON object per line.
# Open it as UTF-8 explicitly (the texts are Cyrillic) so the result does not
# depend on the platform's default encoding, and skip blank lines, which
# json.loads would otherwise reject.
all_lines = []
with open(datapath, encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            continue
        file_ = json.loads(line)
        all_lines.append(file_)

In [None]:
data = pd.DataFrame(all_lines)
data.head()

Unnamed: 0,file_name,text,entities,sentence_id
0,172744.tsv,"нам прописали, так мой ребенок сыпью покрылся,...","[{'start': 122, 'entity_type': 'Drugform', 'en...",0
1,172744.tsv,Общее впечатление : не подошел\n,[],1
2,592814.tsv,Пила этот препарат для повышения иммунитета 5 ...,"[{'start': 23, 'entity_type': 'DI', 'end': 43,...",0
3,592814.tsv,"Так как начала работать в аптеке, начала часто...",[],1
4,592814.tsv,В месяц по нескольку раз причем со всеми вытек...,"[{'start': 66, 'entity_type': 'DI', 'end': 72,...",2


## Синонимы

In [None]:
synonyms_df = pd.read_csv(folder_path + '/synonims.csv', sep='\t', index_col=0)
synonyms_df.head()

Unnamed: 0,concept_id,entity_text,entity_type
0,C0000731,в животе было вздутие,ADR
1,C0000731,вздутие живота,ADR
2,C0001047,ацц,Drugname
3,C0001367,ацикловир сандоз,Drugname
4,C0001367,ацикловира,Drugname


In [None]:
def augment_text_and_update_indices(df, synonyms_df):
    """Replace entity mentions with a random synonym of the same concept.

    Args:
        df: DataFrame with the columns ``file_name``, ``text``, ``entities``
            and ``sentence_id``. ``entities`` holds a list of dicts with at
            least ``start``, ``end``, ``entity_type`` and ``entity_id``
            (character offsets into ``text``).
        synonyms_df: DataFrame with the columns ``concept_id`` and
            ``entity_text``; every row is one surface form of a concept.

    Returns:
        A list of dicts (``file_name``, ``text``, ``entities``,
        ``sentence_id``), one per input row. Entity offsets refer to the new
        text, and entities are listed in order of their position in the text.

    Notes:
        * Entities are processed left to right. The annotation does not list
          them in text order, and a running offset is only valid when the
          replacements are applied in increasing ``start`` order.
        * Entities without a known synonym are kept (shifted, not replaced)
          so that their labels are not lost.
        * An entity that starts inside a span that has just been replaced is
          dropped, because its original span no longer exists. An entity that
          overlaps an already emitted, unreplaced entity is kept but not
          replaced, so the earlier entity's offsets stay valid.
    """
    # Build the concept -> synonyms lookup once instead of filtering the whole
    # synonyms frame for every single entity.
    synonyms_by_concept = {}
    for concept_id, synonym in zip(synonyms_df['concept_id'], synonyms_df['entity_text']):
        # Skip missing ids and non-string / empty synonyms (e.g. NaN cells).
        if pd.isna(concept_id) or not isinstance(synonym, str) or not synonym:
            continue
        synonyms_by_concept.setdefault(concept_id, []).append(synonym)

    augmented_data = []
    for index, row in tqdm(df.iterrows(), total=df.shape[0]):
        text = row['text']
        updated_entities = []
        offset = 0          # length change accumulated by replacements so far
        replaced_until = 0  # original end offset of the last replaced span
        kept_until = 0      # original end offset covered by unreplaced entities

        for entity in sorted(row['entities'], key=lambda ent: (ent['start'], ent['end'])):
            start, end = entity['start'], entity['end']
            if start < replaced_until:
                # The text this entity pointed at has been overwritten.
                continue

            new_start = start + offset
            candidates = synonyms_by_concept.get(entity.get('concept_id'))
            if candidates and start >= kept_until:
                entity_text = random.choice(candidates)
                text = text[:new_start] + entity_text + text[end + offset:]
                new_end = new_start + len(entity_text)
                offset += len(entity_text) - (end - start)
                replaced_until = end
            else:
                new_end = end + offset
                entity_text = entity.get('entity_text', text[new_start:new_end])
                kept_until = max(kept_until, end)

            updated_entities.append({
                "start": new_start,
                "entity_type": entity['entity_type'],
                "end": new_end,
                "entity_id": entity['entity_id'],
                "entity_text": entity_text,
                "concept_id": entity.get('concept_id'),
                "concept_name": entity.get('concept_name')
            })

        augmented_data.append({
            "file_name": row['file_name'],
            "text": text,
            "entities": updated_entities,
            "sentence_id": row['sentence_id']
        })
    return augmented_data

In [None]:
augmented_data = augment_text_and_update_indices(data, synonyms_df)

100%|██████████| 4809/4809 [00:02<00:00, 1604.78it/s]


In [None]:
augmented_data[:3]

[{'file_name': '172744.tsv',
  'text': 'нам прописали, так мой ребенок сыпь уже была на груди, глаза опухли, сверху и снизу высыпания, ( 8 месяцев сыну)А от виферону такого не было... У кого ещё такие побочки, отзовитесь!1 Чем спасались?\n',
  'entities': [{'start': 122,
    'entity_type': 'Drugform',
    'end': 130,
    'entity_id': '*[0]_se',
    'entity_text': 'виферону',
    'concept_id': 'C0021735',
    'concept_name': nan},
   {'start': 31,
    'entity_type': 'ADR',
    'end': 53,
    'entity_id': '*[1]',
    'entity_text': 'сыпь уже была на груди',
    'concept_id': 'C0015230',
    'concept_name': nan},
   {'start': 55,
    'entity_type': 'ADR',
    'end': 67,
    'entity_id': '*[2]',
    'entity_text': 'глаза опухли',
    'concept_id': 'C4760994',
    'concept_name': nan},
   {'start': 84,
    'entity_type': 'ADR',
    'end': 93,
    'entity_id': '*[3]',
    'entity_text': 'высыпания',
    'concept_id': 'C0015230',
    'concept_name': nan}],
  'sentence_id': 0},
 {'file_name': 

In [None]:
file_path = 'augmented_synonyms.json'

# Persist the augmented records as JSON Lines (one object per line).
# ensure_ascii=False keeps the Cyrillic text readable instead of \u-escaped.
with open(file_path, mode='w', encoding='utf-8') as f:
    for item in augmented_data:
        f.write(json.dumps(item, ensure_ascii=False) + '\n')

In [None]:
# NOTE(review): 'umls_concepts.csv' is only written further down, by the cell
# that calls new_df.to_csv(...). On a fresh runtime this cell therefore fails
# unless the file is already present - confirm the intended cell order.
umls_df = pd.read_csv('umls_concepts.csv')
augmented_data = augment_text_and_update_indices(data, synonyms_df=umls_df)

100%|██████████| 4809/4809 [00:03<00:00, 1418.80it/s]


In [None]:
augmented_data[:3]

[{'file_name': '172744.tsv',
  'text': 'нам прописали, так мой ребенок экзантема, опухшие глаза, сверху и снизу экзантема, ( 8 месяцев сыну)А от виферона такого не было... У кого ещё такие побочки, отзовитесь!1 Чем спасались?\n',
  'entities': [{'start': 31,
    'entity_type': 'ADR',
    'end': 40,
    'entity_id': '*[1]',
    'entity_text': 'экзантема',
    'concept_id': 'C0015230',
    'concept_name': nan},
   {'start': 42,
    'entity_type': 'ADR',
    'end': 55,
    'entity_id': '*[2]',
    'entity_text': 'опухшие глаза',
    'concept_id': 'C4760994',
    'concept_name': nan},
   {'start': 72,
    'entity_type': 'ADR',
    'end': 81,
    'entity_id': '*[3]',
    'entity_text': 'экзантема',
    'concept_id': 'C0015230',
    'concept_name': nan}],
  'sentence_id': 0},
 {'file_name': '172744.tsv',
  'text': 'Общее впечатление : не подошел\n',
  'entities': [],
  'sentence_id': 1},
 {'file_name': '592814.tsv',
  'text': 'Пила этот препарат для бустерная доза вакцины 5 лет назад.\n',
  

In [None]:
file_path = 'augmented_umls.json'

# Persist the UMLS-based augmentation as JSON Lines (one object per line).
# ensure_ascii=False keeps the Cyrillic text readable instead of \u-escaped.
with open(file_path, mode='w', encoding='utf-8') as f:
    for item in augmented_data:
        f.write(json.dumps(item, ensure_ascii=False) + '\n')

## UMLS

In [None]:
!pip install transliterate

Collecting transliterate
  Downloading transliterate-1.10.2-py2.py3-none-any.whl (45 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/45.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m45.8/45.8 kB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: transliterate
Successfully installed transliterate-1.10.2


In [None]:
import re
from transliterate import translit
import requests

def improve_transliteration(text):
    text = text.replace('тс', 'ц')
    text = text.replace('цх', 'ч')
    text = text.replace('иа', 'я')
    return text

def is_cyrillic(text):
    """Return True if `text` contains at least one character in U+0400..U+04FF."""
    return re.search('[\u0400-\u04FF]', text) is not None

def ensure_cyrillic(name):
    """Return `name` lower-cased and in Cyrillic script.

    A name that already contains a Cyrillic character is only lower-cased.
    Any other name is transliterated to Russian, lower-cased and passed
    through improve_transliteration.
    """
    if not is_cyrillic(name):
        transliterated = translit(name, 'ru')
        return improve_transliteration(transliterated.lower())
    return name.lower()

def get_names_for_concept(apikey, version, concept_id):
    """Fetch the Russian names of a UMLS concept through the UTS REST API.

    Args:
        apikey: UMLS (UTS) API key.
        version: UMLS release to query, e.g. ``'current'``.
        concept_id: Concept Unique Identifier (CUI), e.g. ``'C0020971'``.

    Returns:
        A list of unique, lower-cased Cyrillic names in the order they were
        returned by the API. The list is empty when the identifier is not a
        well-formed CUI, the concept cannot be retrieved, or it has no
        Russian atoms. Problems are printed rather than raised, so a long
        batch run keeps going.
    """
    names = []

    # A CUI is the letter "C" followed by seven digits. The annotations also
    # contain values such as 'C0443158|C1272745', '?/C0235568' or None; those
    # would be interpolated into the URL path, so reject them before any
    # request is made.
    if not re.fullmatch(r'C\d{7}', str(concept_id)):
        print(f"Concept ID '{concept_id}' is not a valid UMLS CUI, skipping.")
        return names

    content_endpoint = f'https://uts-ws.nlm.nih.gov/rest/content/{version}/CUI/{concept_id}'
    query = {'apiKey': apikey, 'language': 'RUS'}
    try:
        # A timeout keeps a stalled connection from hanging the whole batch.
        response = requests.get(content_endpoint, params=query, timeout=30)
        response.encoding = 'utf-8'

        # Report the actual status: a 401 (bad key) is not a missing concept.
        if response.status_code != 200:
            print(f"Concept ID '{concept_id}' not found in UMLS (HTTP {response.status_code}).")
            return names

        atoms_url = response.json().get('result', {}).get('atoms')
    except (requests.RequestException, ValueError) as request_error:
        print(f"Error processing concept ID '{concept_id}': {request_error}")
        return names

    # The API uses the literal string "NONE" for links that do not exist.
    if not atoms_url or atoms_url == 'NONE':
        print(f"No Russian atoms found for concept ID '{concept_id}'.")
        return names

    seen = set()
    page = 0
    try:
        while True:
            page += 1
            # Ask the server for Russian atoms only; this cuts the number of
            # pages considerably. The client-side language check below is
            # kept as a safeguard.
            atom_query = {'apiKey': apikey, 'language': 'RUS', 'pageNumber': page}
            atoms_response = requests.get(atoms_url, params=atom_query, timeout=30)
            atoms_response.encoding = 'utf-8'

            # Requesting a page past the last one (or a filter without
            # results) yields a non-200 answer.
            if atoms_response.status_code != 200:
                break

            payload = atoms_response.json()
            for atom in payload.get('result', []):
                if atom.get('language') != 'RUS':
                    continue
                name = ensure_cyrillic(atom['name'])
                if name not in seen:  # keep first occurrence, preserve order
                    seen.add(name)
                    names.append(name)

            # Stop on the last page instead of provoking an error response.
            page_count = payload.get('pageCount')
            if isinstance(page_count, int) and page >= page_count:
                break
    except Exception as except_error:
        # Deliberately best-effort: report the problem and return what was
        # collected so far.
        print(f"Error processing concept ID '{concept_id}': {except_error}")

    return names

In [None]:
# The UMLS API key is a personal credential and must not be stored in the
# notebook. Read it from the UMLS_API_KEY environment variable, or ask for it
# interactively. NOTE(review): the key that used to be hardcoded here has been
# exposed and should be revoked in the UTS profile.
apikey = os.environ.get('UMLS_API_KEY')
if not apikey:
    from getpass import getpass
    apikey = getpass('UMLS API key: ')
version = 'current'
identifier = 'C0020971'  # example CUI
get_names_for_concept(apikey, version, identifier)

['иммунизация',
 'вакцинация',
 'иммунологическая стимуляция',
 'иммуностимуляция',
 'сенсибилизация иммунологическая',
 'иммунизации',
 'бустерная доза вакцины']

In [None]:
# Collect the distinct concept_id values found across all annotated entities.
unique_concept_ids = {
    entity['concept_id']
    for row in data['entities']
    for entity in row
    if 'concept_id' in entity
}

In [None]:
import pandas as pd
from tqdm import tqdm

rows = []
# Iterate in a fixed order: a set of strings is iterated in a different order
# in every interpreter session, which would make the resulting CSV differ
# between runs. key=str copes with non-string ids (None, NaN) in the set.
for concept_id in tqdm(sorted(unique_concept_ids, key=str)):
    names = get_names_for_concept(apikey, version, concept_id)
    for name in names:
        rows.append({'concept_id': concept_id, 'entity_text': name.lower()})

# Drop duplicate (concept_id, entity_text) pairs while keeping first-seen
# order; going through a set of tuples would shuffle the rows.
new_df = (
    pd.DataFrame(rows, columns=['concept_id', 'entity_text'])
    .drop_duplicates(ignore_index=True)
)
unique_rows = new_df.to_dict('records')

  4%|▍         | 21/530 [01:12<24:22,  2.87s/it]

Concept ID 'C0010200C0010200' not found in UMLS.


  8%|▊         | 41/530 [02:19<26:44,  3.28s/it]

Concept ID 'None' not found in UMLS.


 10%|█         | 54/530 [03:01<25:14,  3.18s/it]

Concept ID '?чего' not found in UMLS.


 10%|█         | 55/530 [03:01<18:29,  2.34s/it]

Concept ID 'C0030193|?' not found in UMLS.


 12%|█▏        | 64/530 [03:22<14:23,  1.85s/it]

Concept ID 'C0443158|C1272745' not found in UMLS.


 16%|█▌        | 85/530 [04:30<11:01,  1.49s/it]

Concept ID 'C0003862|C0231528' not found in UMLS.


 18%|█▊        | 94/530 [04:58<19:52,  2.73s/it]

Concept ID '?C1971624' not found in UMLS.


 19%|█▉        | 100/530 [05:15<15:06,  2.11s/it]

Concept ID 'C0015230|C0239521' not found in UMLS.


 21%|██        | 109/530 [05:38<12:08,  1.73s/it]

Concept ID 'C0019159|C0085293' not found in UMLS.


 21%|██        | 111/530 [05:44<14:28,  2.07s/it]

Concept ID 'C0085281|C0439857' not found in UMLS.


 23%|██▎       | 124/530 [06:31<16:21,  2.42s/it]

Concept ID 'C0038435|C0678683' not found in UMLS.


 25%|██▍       | 132/530 [06:58<15:41,  2.36s/it]

Concept ID 'C0021400|C0029341' not found in UMLS.


 26%|██▌       | 136/530 [07:03<09:08,  1.39s/it]

Concept ID 'C0031350|C0877467' not found in UMLS.


 27%|██▋       | 143/530 [07:17<10:20,  1.60s/it]

Concept ID 'C0242429|C0240564' not found in UMLS.


 28%|██▊       | 148/530 [07:25<08:48,  1.38s/it]

Concept ID 'C0015230|C0497365' not found in UMLS.


 28%|██▊       | 151/530 [07:29<08:32,  1.35s/it]

Concept ID 'C0035242|C0155839' not found in UMLS.


 31%|███       | 163/530 [08:09<20:36,  3.37s/it]

Concept ID 'C1272745|C0025260' not found in UMLS.


 34%|███▍      | 181/530 [09:00<17:29,  3.01s/it]

Concept ID 'C0281856|C0423673' not found in UMLS.


 39%|███▉      | 206/530 [10:25<17:18,  3.21s/it]

Concept ID 'C0520559|' not found in UMLS.


 43%|████▎     | 228/530 [11:35<11:51,  2.36s/it]

Concept ID 'C0151825|C0003862' not found in UMLS.


 51%|█████     | 270/530 [13:38<10:05,  2.33s/it]

Concept ID 'C0038435|C0023670' not found in UMLS.


 53%|█████▎    | 282/530 [14:18<12:19,  2.98s/it]

Concept ID '?C0700184' not found in UMLS.


 54%|█████▍    | 288/530 [14:33<08:01,  1.99s/it]

Concept ID 'Piracetam' not found in UMLS.


 55%|█████▍    | 291/530 [14:38<06:30,  1.63s/it]

Concept ID 'C2896443|C2016977' not found in UMLS.


 59%|█████▉    | 315/530 [15:54<07:27,  2.08s/it]

Concept ID '?C0861172|C0858620' not found in UMLS.


 60%|█████▉    | 316/530 [15:54<05:35,  1.57s/it]

Concept ID '?/C0235568' not found in UMLS.


 61%|██████    | 322/530 [16:07<05:05,  1.47s/it]

Concept ID 'C0032285|C0028778' not found in UMLS.


 62%|██████▏   | 326/530 [16:21<07:31,  2.21s/it]

Concept ID 'C0240577|C0013604' not found in UMLS.


 62%|██████▏   | 329/530 [16:33<09:37,  2.87s/it]

Concept ID '?C3163857' not found in UMLS.


 63%|██████▎   | 336/530 [16:50<06:15,  1.93s/it]

Concept ID 'nan' not found in UMLS.


 64%|██████▍   | 339/530 [16:59<06:47,  2.13s/it]

Concept ID 'C3887661|C0010200' not found in UMLS.


 64%|██████▍   | 340/530 [16:59<05:04,  1.60s/it]

Concept ID '?C0015300' not found in UMLS.


 64%|██████▍   | 341/530 [17:00<03:52,  1.23s/it]

Concept ID 'C0003862|C3805216|C0239589' not found in UMLS.


 67%|██████▋   | 353/530 [17:38<07:14,  2.45s/it]

Concept ID 'C0239521|C0423777' not found in UMLS.


 67%|██████▋   | 355/530 [17:40<04:35,  1.57s/it]

Concept ID 'C1272745|C0025962' not found in UMLS.


 69%|██████▉   | 367/530 [18:29<11:11,  4.12s/it]

Concept ID 'C0070543)' not found in UMLS.


 72%|███████▏  | 381/530 [19:24<08:29,  3.42s/it]

Concept ID '?C0235169' not found in UMLS.


 73%|███████▎  | 389/530 [19:51<05:33,  2.37s/it]

Concept ID 'C0741585|C0003862' not found in UMLS.


 78%|███████▊  | 411/530 [21:13<05:57,  3.01s/it]

Concept ID 'C0033774|C0475858' not found in UMLS.


 79%|███████▊  | 417/530 [21:30<04:23,  2.33s/it]

Concept ID '?' not found in UMLS.


 82%|████████▏ | 432/530 [22:15<03:06,  1.90s/it]

Concept ID 'C1272745|C0005775' not found in UMLS.


 93%|█████████▎| 492/530 [25:15<01:29,  2.36s/it]

Concept ID 'C1272745|C0037817' not found in UMLS.


 95%|█████████▌| 505/530 [25:49<00:57,  2.28s/it]

Concept ID 'C1272745|C0037313' not found in UMLS.


 99%|█████████▉| 527/530 [26:59<00:06,  2.17s/it]

Concept ID 'C0003419|C3178748' not found in UMLS.


100%|██████████| 530/530 [27:10<00:00,  3.08s/it]


In [None]:
new_df.head()

Unnamed: 0,concept_id,entity_text
0,C0043352,сухая ротовая полость
1,C3887612,беспокойство
2,C0021051,иммунологической недостаточности синдромы
3,C0025517,нарушение обмена веществ
4,C0020281,перекись водорода


In [None]:
len(new_df)

1579

In [None]:
new_df.to_csv('umls_concepts.csv', index=False, encoding='utf-8')

In [None]:
! pip install pymorphy2

Collecting pymorphy2
  Downloading pymorphy2-0.9.1-py3-none-any.whl (55 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m55.5/55.5 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting dawg-python>=0.7.1 (from pymorphy2)
  Downloading DAWG_Python-0.7.2-py2.py3-none-any.whl (11 kB)
Collecting pymorphy2-dicts-ru<3.0,>=2.4 (from pymorphy2)
  Downloading pymorphy2_dicts_ru-2.4.417127.4579844-py2.py3-none-any.whl (8.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.2/8.2 MB[0m [31m40.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting docopt>=0.6 (from pymorphy2)
  Downloading docopt-0.6.2.tar.gz (25 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: docopt
  Building wheel for docopt (setup.py) ... [?25l[?25hdone
  Created wheel for docopt: filename=docopt-0.6.2-py2.py3-none-any.whl size=13706 sha256=f881c6934cac954d60d75185cb413ed813f461d7ae082cc186e5f53fe8eb85b4
  Stored in directory: /root/.

In [None]:
!pip install spacy
!python -m spacy download ru_core_news_sm

Collecting ru-core-news-sm==3.7.0
  Downloading https://github.com/explosion/spacy-models/releases/download/ru_core_news_sm-3.7.0/ru_core_news_sm-3.7.0-py3-none-any.whl (15.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m15.3/15.3 MB[0m [31m16.4 MB/s[0m eta [36m0:00:00[0m
Collecting pymorphy3>=1.0.0 (from ru-core-news-sm==3.7.0)
  Downloading pymorphy3-2.0.1-py3-none-any.whl (53 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.2/53.2 kB[0m [31m911.4 kB/s[0m eta [36m0:00:00[0m
Collecting pymorphy3-dicts-ru (from pymorphy3>=1.0.0->ru-core-news-sm==3.7.0)
  Downloading pymorphy3_dicts_ru-2.4.417150.4580142-py2.py3-none-any.whl (8.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.4/8.4 MB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: pymorphy3-dicts-ru, pymorphy3, ru-core-news-sm
Successfully installed pymorphy3-2.0.1 pymorphy3-dicts-ru-2.4.417150.4580142 ru-core-news-sm-3.7.0
[38;5;2m

In [None]:
import spacy
import pymorphy2

# spaCy Russian pipeline; match_case_and_number reads POS tags and dependency heads from it.
nlp = spacy.load('ru_core_news_sm')
# pymorphy2 analyzer; match_case_and_number uses it to read case/number and to inflect words.
morph = pymorphy2.MorphAnalyzer()

def match_case_and_number(original_phrase, new_phrase):
    """Inflect `new_phrase` to the case and number of `original_phrase`.

    The grammatical case and number are taken from the first noun of
    `original_phrase` (most probable pymorphy2 parse). In `new_phrase`, every
    adjective is inflected, and every noun that is the syntactic root or
    whose head is not a noun; dependent nouns (e.g. a genitive modifier)
    are left as they are.

    Args:
        original_phrase: Phrase whose grammatical form should be matched.
        new_phrase: Phrase to be inflected.

    Returns:
        The inflected phrase. The spacing of `new_phrase` is preserved, so
        punctuation and hyphens are not surrounded by extra spaces.
        `new_phrase` is returned unchanged when `original_phrase` has no
        noun or when its case or number cannot be determined.
    """
    # Find the first noun of the original phrase.
    main_word_token = next(
        (token for token in nlp(original_phrase) if token.pos_ == 'NOUN'),
        None,
    )
    if main_word_token is None:
        return new_phrase

    main_tag = morph.parse(main_word_token.text)[0].tag
    case, number = main_tag.case, main_tag.number
    if not (case and number):
        # Nothing to match against, so keep the phrase as it is.
        return new_phrase

    grammemes = {case, number}
    pieces = []
    for token in nlp(new_phrase):
        word = token.text
        should_inflect = token.pos_ == 'ADJ' or (
            token.pos_ == 'NOUN'
            and (token.dep_ == 'ROOT' or token.head.pos_ != 'NOUN')
        )
        if should_inflect:
            # inflect() returns None when the requested form does not exist.
            inflected = morph.parse(token.text)[0].inflect(grammemes)
            if inflected:
                word = inflected.word
        # Re-attach the token's own trailing whitespace instead of joining
        # with ' ', which would turn "a, b" into "a , b".
        pieces.append(word + token.whitespace_)

    return ''.join(pieces)

result = match_case_and_number('реакция аллергическая', 'снижением иммунитета')
print(result)

снижение иммунитета


In [None]:
match_case_and_number('головной болью', 'сильная мигрень')

'сильной мигренью'

In [None]:
match_case_and_number('снижением иммунитета', 'аллергия')

'аллергией'

### Замена с инфлектированием по числу и падежу

In [None]:
def augment_text_and_update_indices(df, synonyms_df):
    """Replace entity mentions with an inflected synonym of the same concept.

    Same contract as the earlier definition of this name (which this one
    shadows), except that the chosen synonym is first passed through
    match_case_and_number so that it takes the case and number of the
    mention it replaces.

    Args:
        df: DataFrame with the columns ``file_name``, ``text``, ``entities``
            and ``sentence_id``. ``entities`` holds a list of dicts with at
            least ``start``, ``end``, ``entity_type`` and ``entity_id``
            (character offsets into ``text``).
        synonyms_df: DataFrame with the columns ``concept_id`` and
            ``entity_text``; every row is one surface form of a concept.

    Returns:
        A list of dicts (``file_name``, ``text``, ``entities``,
        ``sentence_id``), one per input row. Entity offsets refer to the new
        text, and entities are listed in order of their position in the text.

    Notes:
        * Entities are processed left to right. The annotation does not list
          them in text order, and a running offset is only valid when the
          replacements are applied in increasing ``start`` order.
        * Entities without a known synonym are kept (shifted, not replaced)
          so that their labels are not lost.
        * An entity that starts inside a span that has just been replaced is
          dropped, because its original span no longer exists. An entity that
          overlaps an already emitted, unreplaced entity is kept but not
          replaced, so the earlier entity's offsets stay valid.
    """
    # Build the concept -> synonyms lookup once instead of filtering the whole
    # synonyms frame for every single entity.
    synonyms_by_concept = {}
    for concept_id, synonym in zip(synonyms_df['concept_id'], synonyms_df['entity_text']):
        # Skip missing ids and non-string / empty synonyms (e.g. NaN cells).
        if pd.isna(concept_id) or not isinstance(synonym, str) or not synonym:
            continue
        synonyms_by_concept.setdefault(concept_id, []).append(synonym)

    augmented_data = []
    for index, row in tqdm(df.iterrows(), total=df.shape[0]):
        text = row['text']
        updated_entities = []
        offset = 0          # length change accumulated by replacements so far
        replaced_until = 0  # original end offset of the last replaced span
        kept_until = 0      # original end offset covered by unreplaced entities

        for entity in sorted(row['entities'], key=lambda ent: (ent['start'], ent['end'])):
            start, end = entity['start'], entity['end']
            if start < replaced_until:
                # The text this entity pointed at has been overwritten.
                continue

            new_start = start + offset
            candidates = synonyms_by_concept.get(entity.get('concept_id'))
            if candidates and start >= kept_until:
                synonym = random.choice(candidates)
                original_word = text[new_start:end + offset]
                entity_text = match_case_and_number(original_word, synonym)
                text = text[:new_start] + entity_text + text[end + offset:]
                new_end = new_start + len(entity_text)
                offset += len(entity_text) - (end - start)
                replaced_until = end
            else:
                new_end = end + offset
                entity_text = entity.get('entity_text', text[new_start:new_end])
                kept_until = max(kept_until, end)

            updated_entities.append({
                "start": new_start,
                "entity_type": entity['entity_type'],
                "end": new_end,
                "entity_id": entity['entity_id'],
                "entity_text": entity_text,
                "concept_id": entity.get('concept_id'),
                "concept_name": entity.get('concept_name')
            })

        augmented_data.append({
            "file_name": row['file_name'],
            "text": text,
            "entities": updated_entities,
            "sentence_id": row['sentence_id']
        })
    return augmented_data

In [None]:
umls_df = pd.read_csv('umls_concepts.csv')

In [None]:
augmented_data = augment_text_and_update_indices(data, umls_df)

100%|██████████| 4809/4809 [00:54<00:00, 87.93it/s] 


In [None]:
file_path = 'augmented_umls.json'

# Persist the inflected UMLS augmentation as JSON Lines (one object per line).
# This writes to the same file name as the earlier, non-inflected UMLS export.
with open(file_path, mode='w', encoding='utf-8') as f:
    for item in augmented_data:
        f.write(json.dumps(item, ensure_ascii=False) + '\n')

In [None]:
augmented_data = augment_text_and_update_indices(data, synonyms_df)

100%|██████████| 4809/4809 [01:02<00:00, 77.31it/s] 


In [None]:
file_path = 'augmented_synonyms.json'

# Persist the inflected synonym augmentation as JSON Lines (one object per line).
# This writes to the same file name as the earlier, non-inflected synonyms export.
with open(file_path, mode='w', encoding='utf-8') as f:
    for item in augmented_data:
        f.write(json.dumps(item, ensure_ascii=False) + '\n')

## BERT

In [None]:
!pip install transformers



In [None]:
from transformers import pipeline, AutoModelForMaskedLM, AutoTokenizer
import random
import re

# Checkpoint used for masked-token prediction.
model_name = "DeepPavlov/rubert-base-cased"
# Load the checkpoint with a masked-LM head, plus its matching tokenizer.
model = AutoModelForMaskedLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# fill-mask pipeline; augment_text_and_entities calls it on text containing tokenizer.mask_token.
fill_mask = pipeline('fill-mask', model=model, tokenizer=tokenizer)

def _first_new_candidate(outputs, original):
    """Return the first fill-mask prediction usable as a replacement, or None."""
    for output in outputs or []:
        candidate = output['token_str'].strip()
        # Skip empty predictions and word-piece continuations ('##...'),
        # which are not standalone words.
        if not candidate or candidate.startswith('##'):
            continue
        # Skip predictions that merely reproduce the masked text.
        if candidate.lower() == original.strip().lower():
            continue
        return candidate
    return None


def _shift_entities_from(entities, position, delta, skip=None):
    """Shift, in place, every entity that starts at or after `position`."""
    for ent in entities:
        if ent is not skip and ent['start'] >= position:
            ent['start'] += delta
            ent['end'] += delta


def augment_text_and_entities(text, entities):
    """Replace one span of `text` with a masked-language-model prediction.

    With probability 1/2 (and only if there are entities) a random entity is
    masked and replaced by the model's prediction; otherwise a random word
    that does not overlap any entity is replaced. Entities located after the
    replaced span are shifted by the change in length.

    Args:
        text: Sentence to augment.
        entities: List of entity dicts with character offsets ``start`` and
            ``end`` (and ``entity_text``). The list and its dicts are not
            modified.

    Returns:
        ``(new_text, new_entities)`` where ``new_entities`` is a list of
        copies of the input dicts with updated offsets. If the model offers
        no usable prediction, the text and offsets are returned unchanged.

    Note:
        Entities that overlap or are nested inside a replaced entity are not
        adjusted.
    """
    # Work on copies: the caller's dicts are shared with the source DataFrame
    # and with the raw records it was built from.
    augmented_entities = [dict(ent) for ent in entities]

    if random.choice([True, False]) and augmented_entities:
        entity = random.choice(augmented_entities)
        start, end = entity['start'], entity['end']
        masked_text = text[:start] + tokenizer.mask_token + text[end:]
        new_entity_text = _first_new_candidate(fill_mask(masked_text), text[start:end])

        if new_entity_text is not None:
            delta = len(new_entity_text) - (end - start)
            text = text[:start] + new_entity_text + text[end:]
            # '>=' matters: an entity starting exactly where the replaced one
            # ended has to move as well.
            _shift_entities_from(augmented_entities, end, delta, skip=entity)
            entity['end'] = start + len(new_entity_text)
            entity['entity_text'] = new_entity_text
    else:
        # Candidate words are identified by their character spans, so they
        # can be compared with the (character-based) entity offsets and
        # replaced at the exact position of the chosen occurrence.
        free_words = [
            match for match in re.finditer(r'\b\w+\b', text)
            if all(match.end() <= ent['start'] or match.start() >= ent['end']
                   for ent in augmented_entities)
        ]
        if free_words:
            word = random.choice(free_words)
            word_start, word_end = word.start(), word.end()
            masked_text = text[:word_start] + tokenizer.mask_token + text[word_end:]
            new_word = _first_new_candidate(fill_mask(masked_text), word.group())

            if new_word is not None:
                delta = len(new_word) - (word_end - word_start)
                text = text[:word_start] + new_word + text[word_end:]
                _shift_entities_from(augmented_entities, word_end, delta)

    return text, augmented_entities

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/642 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/714M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/24.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/1.65M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [None]:
def apply_augmentation(row):
    """Run augment_text_and_entities on one row and store the result back into it.

    The row's 'text' and 'entities' fields are overwritten with the augmented
    values, and the same row object is returned.
    """
    row['text'], row['entities'] = augment_text_and_entities(row['text'], row['entities'])
    return row

augmented_data = data.apply(apply_augmentation, axis=1)

In [None]:
augmented_data.head(5)

Unnamed: 0,file_name,text,entities,sentence_id
0,172744.tsv,"нам прописали, так мой ребенок сыпью покрылся,...","[{'start': 123, 'entity_type': 'Drugform', 'en...",0
1,172744.tsv,Общее впечатление : я подошел\n,[],1
2,592814.tsv,Пила этот препарат для женщин 5 лет назад.\n,"[{'start': 23, 'entity_type': 'DI', 'end': 29,...",0
3,592814.tsv,"Так как начала работать на аптеке, начала част...",[],1
4,592814.tsv,В месяц по нескольку раз причем со всеми вытек...,"[{'start': 66, 'entity_type': 'DI', 'end': 72,...",2


In [None]:
augmented_data.to_json('augmented_bert.json', orient='records', lines=True, force_ascii=False)

In [None]:
# NOTE(review): `augmented_data_path` is not defined anywhere in the visible
# notebook, so this cell presumably fails on a fresh runtime - confirm.
# NOTE(review): this writes `all_lines` (the records loaded from the source
# file), not `augmented_data` - verify that this is the intended content.
with open(augmented_data_path, mode='w', encoding='utf-8') as f_out:
    for line in tqdm(all_lines):
        f_out.write(json.dumps(line, ensure_ascii=False))
        f_out.write('\n')

100%|██████████| 4809/4809 [00:00<00:00, 24681.34it/s]
