In [None]:
import os, sys

BASE_PATH = ''

os.chdir(BASE_PATH)

In [11]:
import re
import yaml
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

In [3]:
with open('data/mkb2descr_new.yaml', 'r') as f:
    mkb2descr = yaml.safe_load(f)
descr2mkb = {f'{v} - {k}': k for k, v in mkb2descr.items()}
len(mkb2descr), len(descr2mkb)

(17762, 17762)

In [49]:
with open('data/idx2mkb.yaml', 'r') as f:
    idx2mkb = yaml.safe_load(f)
mkb2idx = {mkb: idx for idx, mkb in idx2mkb.items()}
len(mkb2idx)

17312

In [102]:
def transform_predicted_idx(idx: str) -> int:
    idx = re.sub("[^0-9]", "", idx)

    if len(idx) == 0:
        return 1

    if len(idx) == 1:
        if int(idx[0]) == 0:
            return 1
        else:
            return int(idx)

    while int(idx[0]) == 0:
        if len(idx) == 1:
            return 1
        idx = idx[1:]

    if int(idx[0]) > 1:
        return int(idx[0])
    
    idx = idx[:2]
    return int(idx)

In [103]:
import difflib
def find_similar_substring(text, pattern, max_distance=2):
    # Разбиваем текст на подстроки длиной, равной длине шаблона
    for i in range(len(text) - len(pattern) + 1):
        substring = text[i:i+len(pattern)]
        # Вычисляем расстояние Левенштейна
        distance = difflib.SequenceMatcher(None, substring, pattern).ratio()
        # Если расстояние меньше или равно max_distance, возвращаем позицию
        if distance >= 1 - (max_distance / len(pattern)):
            return i
    return -1

def form_ann_str(text: str, preds: dict, mkb2idx: dict) -> str:
    ann_str = ''
    offset = 0
    grouped_preds = list(zip(preds['entities'], preds['mkbs']))
    for i, (entity, code) in enumerate(grouped_preds, start=1):

        start = text.find(entity, offset)
        if start == -1:
            start = find_similar_substring(text, entity)
        end = start + len(entity)
        if start > 0:
            offset = end

        ann_str += f"T{i}\ticd_code {start} {end}\t{entity}\n"
        ann_str += f"N{i}\tReference T{i} ICD_codes:{mkb2idx.get(code, -1)}\t{code}\n"

    return ann_str

def group_case(case):
    return (
        case['text'].iloc[0], 
        {
            'entities': case['entity'].to_list(),
            'mkbs': case['pred_mkb'].to_list()
        }
    )

In [6]:
oper_test_df = pd.read_parquet('data/oper_test_df_v2.parquet')
print(oper_test_df.shape)
oper_test_df.head()

(500, 5)


Unnamed: 0,text,text_labels,labled_entities,mkb_texts,mkb_codes
0,Аднексит слева ?Кисты экзоцервикса.,<mkb code=N70.9>Аднексит</mkb> слева ?<mkb co...,<mkb>Аднексит</mkb> слева ?<mkb>Кисты экзоцер...,"[Аднексит, Кисты экзоцервикса]","[N70.9, N88.8]"
1,Эндоцервицит?. Вагинит.? Хр. сальпингоофорит в...,Эндоцервицит?. Вагинит.? Хр. сальпингоофорит в...,Эндоцервицит?. Вагинит.? Хр. сальпингоофорит в...,"[ Гиперандрогени, Вагинит, сальпингоофорит, Хр...","[E28.1, N76.0, N70.1, N88, D25.9, N80.0, N71.1..."
2,"ОРВИ. Острый ринофарингиотрахеит, течение",<mkb code=J00-J06>ОРВИ</mkb>. <mkb code=J00>Ос...,<mkb>ОРВИ</mkb>. <mkb>Острый ринофарингиотрахе...,"[ОРВИ, Острый ринофарингиотрахеит]","[J00-J06, J00]"
3,"Острый трахеит , течение на фоне ОРВИ, острый...","<mkb code=J04.1>Острый трахеит</mkb> , течение...","<mkb>Острый трахеит</mkb> , течение на фоне <...","[Острый трахеит, ОРВИ, острый назофарингит, Д...","[J04.1, J00-J06, J00, L30.9]"
4,Дорсопатия шейного отдела позвоночника на фоне...,<mkb code=M50.8>Дорсопатия шейного отдела позв...,<mkb>Дорсопатия шейного отдела позвоночника</m...,"[Дорсопатия шейного отдела позвоночника, деген...","[M50.8, M42.1, M79.1, M54.2, H81.9]"


In [48]:
text_to_idx = {text: idx for idx, text in enumerate(oper_test_df['text'].to_list())}
len(text_to_idx)

500

In [105]:
def save_model_predicts(
        linking_rag_df_path: str,
        linking_pred_df_path: str,
        save_path: str,
        text2idx: dict,
        mkb2idx: dict,
    ):
    linking_rag_df = pd.read_parquet(linking_rag_df_path)
    linking_pred_df = pd.read_parquet(linking_pred_df_path)

    linking_pred_df['predicted_idx'] = linking_pred_df['predicted_idx'].apply(transform_predicted_idx)
    linking_pred_df['founded_mkbs'] = linking_rag_df['founded_mkbs']
    linking_pred_df['pred_mkb'] = linking_pred_df.apply(
        lambda case: case['founded_mkbs'][
            min(
                len(case['founded_mkbs']) - 1, 
                case['predicted_idx'] - 1
            )
        ], axis=1
    )

    preds_ds = linking_pred_df[['text', 'entity', 'pred_mkb']].groupby('text').apply(group_case)
    preds_ds = dict(preds_ds.to_list())

    for text, idx in text2idx.items():
        preds = preds_ds.get(text, {'entities': [], 'mkbs': []})
        ann_str = form_ann_str(text, preds, mkb2idx)

        with open(os.path.join(save_path, f'{idx}.txt'), 'w') as f:
            f.write(text)
        with open(os.path.join(save_path, f'{idx}.ann'), 'w') as f:
            f.write(ann_str)

In [97]:
preds_paths = [
    (
        'data/preds_retrieval_ds/llama/icd_retrieval_df.parquet',
        'results/linking_preds/icd/m42-health/Llama3-Med42-8B/icd.parquet',
        'results/preds/Llama3-Med42-8B_ner_icd_linking'
    ),
    (
        'data/preds_retrieval_ds/llama/tr_and_icd_retrieval_df.parquet',
        'results/linking_preds/train_and_icd/m42-health/Llama3-Med42-8B/train_and_icd.parquet',
        'results/preds/Llama3-Med42-8B_ner_train_and_icd_linking'
    ),
    (
        'data/preds_retrieval_ds/mistral/icd_retrieval_df.parquet',
        'results/linking_preds/icd/mistralai/Mistral-Nemo-Instruct-2407/icd.parquet',
        'results/preds/Mistral-Nemo_ner_icd_linking'
    ),
    (
        'data/preds_retrieval_ds/mistral/tr_and_icd_retrieval_df.parquet',
        'results/linking_preds/train_and_icd/mistralai/Mistral-Nemo-Instruct-2407/train_and_icd.parquet',
        'results/preds/Mistral-Nemo_ner_train_and_icd_linking'
    ),
    (
        'data/preds_retrieval_ds/phi/icd_retrieval_df.parquet',
        'results/linking_preds/icd/microsoft/Phi-3.5-mini-instruct/icd.parquet',
        'results/preds/Phi3_5_mini_ner_icd_linking'
    ),
    (
        'data/preds_retrieval_ds/phi/tr_and_icd_retrieval_df.parquet',
        'results/linking_preds/train_and_icd/microsoft/Phi-3.5-mini-instruct/train_and_icd.parquet',
        'results/preds/Phi3_5_mini_ner_train_and_icd_linking'
    ),
    (
        'data/preds_retrieval_ds/qwen/icd_retrieval_df.parquet',
        'results/linking_preds/icd/Qwen/Qwen2.5-7B-Instruct/icd.parquet',
        'results/preds/Qwen2.5-7B-Instruct_ner_icd_linking'
    ),
    (
        'data/preds_retrieval_ds/qwen/tr_and_icd_retrieval_df.parquet',
        'results/linking_preds/train_and_icd/Qwen/Qwen2.5-7B-Instruct/train_and_icd.parquet',
        'results/preds/Qwen2.5-7B-Instruct_ner_train_and_icd_linking'
    ),
]

In [106]:
for retrieval_ds_path, preds_ds_path, save_path in tqdm(preds_paths):
    save_model_predicts(retrieval_ds_path, preds_ds_path, save_path, text_to_idx, mkb2idx)

  0%|          | 0/8 [00:00<?, ?it/s]

  preds_ds = linking_pred_df[['text', 'entity', 'pred_mkb']].groupby('text').apply(group_case)
  preds_ds = linking_pred_df[['text', 'entity', 'pred_mkb']].groupby('text').apply(group_case)
  preds_ds = linking_pred_df[['text', 'entity', 'pred_mkb']].groupby('text').apply(group_case)
  preds_ds = linking_pred_df[['text', 'entity', 'pred_mkb']].groupby('text').apply(group_case)
  preds_ds = linking_pred_df[['text', 'entity', 'pred_mkb']].groupby('text').apply(group_case)
  preds_ds = linking_pred_df[['text', 'entity', 'pred_mkb']].groupby('text').apply(group_case)
  preds_ds = linking_pred_df[['text', 'entity', 'pred_mkb']].groupby('text').apply(group_case)
  preds_ds = linking_pred_df[['text', 'entity', 'pred_mkb']].groupby('text').apply(group_case)
