In [None]:
import os, sys

# Project root: every relative path used below ('data/...', 'results/...') is
# resolved against it. Leave it empty to run from the current working directory.
BASE_PATH = ''

# os.chdir('') raises FileNotFoundError, so only change directory when a path is set.
if BASE_PATH:
    os.chdir(BASE_PATH)

In [19]:
import ast
import math
import re
import faiss
import yaml
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from FlagEmbedding import BGEM3FlagModel

In [26]:
# Abbreviation -> full-term dictionary consumed by ABBREV_to_term, whose keys are
# matched against upper-case Cyrillic words. Decode explicitly as UTF-8 instead of
# relying on the platform default encoding (e.g. cp1251 on Windows).
with open('data/abbr2term.yaml', 'r', encoding='utf-8') as f:
    abbr2term = yaml.safe_load(f)
len(abbr2term)

4506

In [3]:
# ICD (МКБ) code -> description mapping. The descriptions are Russian text, so the
# file is decoded explicitly as UTF-8 rather than with the platform default.
with open('data/mkb2descr_new.yaml', 'r', encoding='utf-8') as f:
    mkb2descr = yaml.safe_load(f)
# Reverse lookup keyed by "<description> - <code>" (the same shape as the 'text'
# column of icd_retrieval_df, e.g. "Холера - A00"), mapping back to the code.
descr2mkb = {f'{v} - {k}': k for k, v in mkb2descr.items()}
len(mkb2descr), len(descr2mkb)

(17762, 17762)

In [16]:
# BGE-M3 encoder loaded in half precision. Its 'dense_vecs' output is used below
# to embed query entities; presumably the stored 'emb' columns were produced by
# the same model — TODO confirm, otherwise the FAISS distances are meaningless.
embedder = BGEM3FlagModel(
    'BAAI/bge-m3',
    use_fp16=True,
)

Fetching 30 files:   0%|          | 0/30 [00:00<?, ?it/s]

In [10]:
# ICD catalogue: one row per "<description> - <code>" text with its code ('mkb')
# and a pre-computed embedding ('emb').
icd_retrieval_df = pd.read_parquet(path='data/mkb_descr_and_emb_df.parquet')
icd_retrieval_df.head()

Unnamed: 0,text,mkb,emb
0,Коды для использования при отсутствии диагноза...,00,"[-0.03753662109375, -0.0203399658203125, -0.06..."
1,Холера - A00,A00,"[0.0167694091796875, 0.006465911865234375, -0...."
2,КИШЕЧНЫЕ ИНФЕКЦИИ (A00-A09) - A00-A09,A00-A09,"[-0.0097808837890625, 0.002227783203125, -0.05..."
3,НЕКОТОРЫЕ ИНФЕКЦИОННЫЕ И ПАРАЗИТАРНЫЕ БОЛЕЗНИ ...,A00-B99,"[0.038055419921875, 0.0011386871337890625, -0...."
4,"Холера, вызванная холерным вибрионом 01, биова...",A00.0,"[0.00910186767578125, -0.004512786865234375, -..."


In [11]:
# Labelled training texts with their ICD code and pre-computed embedding.
train_mkb_texts_and_emb_df = pd.read_parquet('data/train_mkb_texts_and_emb_df.parquet')
# text -> embedding lookup; when a text occurs more than once the last row wins.
train_mkb_text2emb = dict(zip(
    train_mkb_texts_and_emb_df['text'],
    train_mkb_texts_and_emb_df['emb'],
))
print(len(train_mkb_text2emb))
train_mkb_texts_and_emb_df.head()

4517


Unnamed: 0,text,text_without_abbr,mkb,emb
0,Эндоцервикоз,Эндоцервикоз,N88.8,"[-0.025390625, 0.06640625, -0.011138916015625,..."
1,Хронический сальпингит и оофорит,Хронический сальпингит и оофорит,N70.1,"[-0.04779052734375, 0.01415252685546875, -0.05..."
2,Спаечный процесс в малом тазу,Спаечный процесс в малом тазу,N73.6,"[0.00469207763671875, 0.0288543701171875, -0.0..."
3,Ожирение,Ожирение,E66.9,"[-0.01204681396484375, 0.0248260498046875, -0...."
4,Полип цервикального канала.,Полип цервикального канала.,N84.1,"[-0.054534912109375, 0.0253448486328125, -0.04..."


In [12]:
# Combined retrieval corpus: labelled training texts followed by the ICD catalogue.
# ignore_index=True yields a clean 0..N-1 RangeIndex, so row labels coincide with
# the positional ids returned by the FAISS index built from this frame (without
# it, labels from the two sources would collide).
train_and_icd_retrieval_df = pd.concat(
    [
        train_mkb_texts_and_emb_df[['text', 'mkb', 'emb']],
        icd_retrieval_df[['text', 'mkb', 'emb']],
    ],
    ignore_index=True,
)
print(train_and_icd_retrieval_df.shape)

(26116, 3)


In [13]:
# Exact (brute-force) L2 index over the ICD catalogue embeddings. IndexFlat ids
# are insertion positions, so hit `i` corresponds to icd_retrieval_df.iloc[i].
icd_retrieval_embeddings = np.vstack(list(icd_retrieval_df['emb']))
icd_retrieval_dimension = icd_retrieval_embeddings.shape[1]

icd_retrieval_index = faiss.IndexFlatL2(icd_retrieval_dimension)
icd_retrieval_index.add(icd_retrieval_embeddings)

In [14]:
# Exact L2 index over the combined train + ICD corpus; ids are row positions in
# train_and_icd_retrieval_df (insertion order).
train_and_icd_retrieval_embeddings = np.vstack(list(train_and_icd_retrieval_df['emb']))
train_and_icd_retrieval_dimension = train_and_icd_retrieval_embeddings.shape[1]

train_and_icd_retrieval_index = faiss.IndexFlatL2(train_and_icd_retrieval_dimension)
train_and_icd_retrieval_index.add(train_and_icd_retrieval_embeddings)

In [20]:
def convert_string_to_list(s):
    """Parse a stringified list of entities such as "['a', 'b']" into a list of str.

    The value is first parsed as a Python literal with ``ast.literal_eval``. This
    correctly handles items quoted with double quotes (Python's repr uses them for
    strings containing an apostrophe) and escaped characters. If that does not
    yield a list/tuple of strings (e.g. a truncated or malformed model output),
    every single-quoted substring is extracted instead.

    Args:
        s: raw cell value. Missing values (None / NaN) give an empty list; a
            list, tuple or ndarray is returned as a list of its string items.

    Returns:
        list[str]: the extracted entities (possibly empty).
    """
    if s is None or (isinstance(s, float) and math.isnan(s)):
        return []
    if isinstance(s, (list, tuple, np.ndarray)):
        return [item for item in s if isinstance(item, str)]

    try:
        parsed = ast.literal_eval(s)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        parsed = None
    if isinstance(parsed, (list, tuple)) and all(isinstance(item, str) for item in parsed):
        return list(parsed)

    # Fallback for non-literal text: take every '...'-quoted chunk.
    return re.findall(r"'([^']*)'", s)

def load_oper_df(df_path):
    """Read a parquet of NER predictions and return one row per predicted entity.

    The 'predicted_entities' column is parsed with ``convert_string_to_list`` and
    then exploded: each entity gets its own row (repeating the original index
    label), and rows whose list is empty keep a single NaN entity.
    """
    raw_preds = pd.read_parquet(df_path)
    parsed_entities = raw_preds['predicted_entities'].apply(convert_string_to_list)
    return raw_preds.assign(predicted_entities=parsed_entities).explode('predicted_entities')

In [27]:
def ABBREV_to_term(text: str, abbr2term: dict) -> str:
    """Expand abbreviations in ``text`` using the ``abbr2term`` dictionary.

    An abbreviation is a whole word of 2-8 upper-case Cyrillic letters
    (including Ё). Every such word that is a key of ``abbr2term`` is replaced by
    its expansion; unknown words are left unchanged.

    Replacement is done in a single pass over the original text, so:
      * only whole-word occurrences are replaced (a short abbreviation is never
        substituted inside a longer upper-case word);
      * expansions are inserted literally (backslashes are not treated as regex
        replacement syntax);
      * inserted expansions are never re-scanned, so repeated abbreviations are
        not expanded twice.

    Args:
        text: input string.
        abbr2term: mapping from abbreviation to its full term.

    Returns:
        The text with known abbreviations expanded.
    """
    # Ё (U+0401) lies outside the А-Я code-point range, so it is listed explicitly.
    abbr_pattern = r'\b[А-ЯЁ]{2,8}\b'

    def _expand(match):
        abbr = match.group(0)
        term = abbr2term.get(abbr)
        return abbr if term is None else str(term)

    return re.sub(abbr_pattern, _expand, text)

In [28]:
def create_oper_retrieval_df(
        embedder,
        initial_df,
        retrieval_df,
        retrieval_index,
        abbr2term,
        n
    ):
    """Retrieve the ``n`` nearest corpus entries for every predicted entity.

    Each entity has its abbreviations expanded with ``ABBREV_to_term``, is
    encoded with ``embedder`` (its ``'dense_vecs'`` output) and is searched in
    ``retrieval_index``. The returned FAISS ids are treated as row positions in
    ``retrieval_df``, so the index must have been built from ``retrieval_df``
    in row order.

    Args:
        embedder: object whose ``encode(str)`` returns a dict containing a
            ``'dense_vecs'`` vector (e.g. ``BGEM3FlagModel``).
        initial_df: frame with exactly two columns, unpacked positionally as
            (source text, predicted entity), e.g. the output of ``load_oper_df``.
            Rows whose entity is not a string (NaN from exploding an empty
            prediction list) are skipped.
        retrieval_df: frame with ``'text'`` and ``'mkb'`` columns aligned with
            ``retrieval_index``.
        retrieval_index: FAISS index exposing ``search(queries, k)``.
        abbr2term: abbreviation -> full term mapping.
        n: number of neighbours to retrieve per entity.

    Returns:
        DataFrame with one row per string entity and columns ``initial_text``,
        ``mkb_text``, ``founded_mkb_descriptions``, ``founded_mkbs`` and
        ``label`` (always 0).
    """
    columns = ['initial_text', 'mkb_text', 'founded_mkb_descriptions', 'founded_mkbs', 'label']
    oper_retrieval_ds = []
    for text, entity in tqdm(initial_df.to_records(index=False)):
        if not isinstance(entity, str):
            continue

        text_emb = embedder.encode(ABBREV_to_term(entity, abbr2term))['dense_vecs']
        # FAISS expects a contiguous 2-D float32 (n_queries, dim) array; the
        # reshape also accepts an encoder that already returns shape (1, dim).
        query = np.ascontiguousarray(text_emb, dtype=np.float32).reshape(1, -1)
        _, idxs = retrieval_index.search(query, n)
        # FAISS pads with -1 when fewer than n neighbours exist; iloc[-1] would
        # silently select the last row, so those ids are dropped.
        found_positions = [i for i in idxs[0].tolist() if i >= 0]
        founded_records = retrieval_df.iloc[found_positions]

        oper_retrieval_ds.append({
            'initial_text': text,
            'mkb_text': entity,
            'founded_mkb_descriptions': founded_records['text'].to_list(),
            'founded_mkbs': founded_records['mkb'].to_list(),
            'label': 0
        })

    # Explicit columns keep the schema stable even when no entity was found.
    return pd.DataFrame(oper_retrieval_ds, columns=columns)

In [29]:
# Llama3-Med42-8B NER predictions -> top-15 candidates from the ICD catalogue only.
llama3_icd_retrieval_df = create_oper_retrieval_df(
    embedder=embedder,
    initial_df=load_oper_df('results/preds/ner_preds/m42-health/Llama3-Med42-8B/v_0.parquet'),
    retrieval_df=icd_retrieval_df,
    retrieval_index=icd_retrieval_index,
    abbr2term=abbr2term,
    n=15,
)

  0%|          | 0/1553 [00:00<?, ?it/s]

In [32]:
# Llama3-Med42-8B NER predictions -> top-15 candidates from train texts + ICD catalogue.
llama3_tr_and_icd_retrieval_df = create_oper_retrieval_df(
    embedder=embedder,
    initial_df=load_oper_df('results/preds/ner_preds/m42-health/Llama3-Med42-8B/v_0.parquet'),
    retrieval_df=train_and_icd_retrieval_df,
    retrieval_index=train_and_icd_retrieval_index,
    abbr2term=abbr2term,
    n=15,
)

  0%|          | 0/1553 [00:00<?, ?it/s]

In [35]:
# Phi-3.5-mini NER predictions -> top-15 candidates from the ICD catalogue only.
phi3_5_icd_retrieval_df = create_oper_retrieval_df(
    embedder=embedder,
    initial_df=load_oper_df('results/preds/ner_preds/microsoft/Phi-3.5-mini-instruct/v_0.parquet'),
    retrieval_df=icd_retrieval_df,
    retrieval_index=icd_retrieval_index,
    abbr2term=abbr2term,
    n=15,
)

  0%|          | 0/1531 [00:00<?, ?it/s]

In [36]:
# Phi-3.5-mini NER predictions -> top-15 candidates from train texts + ICD catalogue.
phi3_5_tr_and_icd_retrieval_df = create_oper_retrieval_df(
    embedder=embedder,
    initial_df=load_oper_df('results/preds/ner_preds/microsoft/Phi-3.5-mini-instruct/v_0.parquet'),
    retrieval_df=train_and_icd_retrieval_df,
    retrieval_index=train_and_icd_retrieval_index,
    abbr2term=abbr2term,
    n=15,
)

  0%|          | 0/1531 [00:00<?, ?it/s]

In [37]:
# Mistral-Nemo NER predictions -> top-15 candidates from the ICD catalogue only.
mistral_icd_retrieval_df = create_oper_retrieval_df(
    embedder=embedder,
    initial_df=load_oper_df('results/preds/ner_preds/mistralai/Mistral-Nemo-Instruct-2407/v_0.parquet'),
    retrieval_df=icd_retrieval_df,
    retrieval_index=icd_retrieval_index,
    abbr2term=abbr2term,
    n=15,
)

  0%|          | 0/1505 [00:00<?, ?it/s]

In [38]:
# Mistral-Nemo NER predictions -> top-15 candidates from train texts + ICD catalogue.
mistral_tr_and_icd_retrieval_df = create_oper_retrieval_df(
    embedder=embedder,
    initial_df=load_oper_df('results/preds/ner_preds/mistralai/Mistral-Nemo-Instruct-2407/v_0.parquet'),
    retrieval_df=train_and_icd_retrieval_df,
    retrieval_index=train_and_icd_retrieval_index,
    abbr2term=abbr2term,
    n=15,
)

  0%|          | 0/1505 [00:00<?, ?it/s]

In [39]:
# Qwen2.5-7B NER predictions -> top-15 candidates from the ICD catalogue only.
qwen_icd_retrieval_df = create_oper_retrieval_df(
    embedder=embedder,
    initial_df=load_oper_df('results/preds/ner_preds/Qwen/Qwen2.5-7B-Instruct/v_0.parquet'),
    retrieval_df=icd_retrieval_df,
    retrieval_index=icd_retrieval_index,
    abbr2term=abbr2term,
    n=15,
)

  0%|          | 0/1589 [00:00<?, ?it/s]

In [40]:
# Qwen2.5-7B NER predictions -> top-15 candidates from train texts + ICD catalogue.
qwen_tr_and_icd_retrieval_df = create_oper_retrieval_df(
    embedder=embedder,
    initial_df=load_oper_df('results/preds/ner_preds/Qwen/Qwen2.5-7B-Instruct/v_0.parquet'),
    retrieval_df=train_and_icd_retrieval_df,
    retrieval_index=train_and_icd_retrieval_index,
    abbr2term=abbr2term,
    n=15,
)

  0%|          | 0/1589 [00:00<?, ?it/s]

In [41]:
# Persist both Llama3 retrieval variants. The target folder is created first
# because to_parquet does not create missing parent directories.
os.makedirs('data/preds_retrieval_ds/llama', exist_ok=True)
llama3_icd_retrieval_df.to_parquet('data/preds_retrieval_ds/llama/icd_retrieval_df.parquet')
llama3_tr_and_icd_retrieval_df.to_parquet('data/preds_retrieval_ds/llama/tr_and_icd_retrieval_df.parquet')

In [42]:
# Persist both Phi-3.5 retrieval variants. The target folder is created first
# because to_parquet does not create missing parent directories.
os.makedirs('data/preds_retrieval_ds/phi', exist_ok=True)
phi3_5_icd_retrieval_df.to_parquet('data/preds_retrieval_ds/phi/icd_retrieval_df.parquet')
phi3_5_tr_and_icd_retrieval_df.to_parquet('data/preds_retrieval_ds/phi/tr_and_icd_retrieval_df.parquet')

In [43]:
# Persist both Mistral retrieval variants. The target folder is created first
# because to_parquet does not create missing parent directories.
os.makedirs('data/preds_retrieval_ds/mistral', exist_ok=True)
mistral_icd_retrieval_df.to_parquet('data/preds_retrieval_ds/mistral/icd_retrieval_df.parquet')
mistral_tr_and_icd_retrieval_df.to_parquet('data/preds_retrieval_ds/mistral/tr_and_icd_retrieval_df.parquet')

In [44]:
# Persist both Qwen retrieval variants. The target folder is created first
# because to_parquet does not create missing parent directories.
os.makedirs('data/preds_retrieval_ds/qwen', exist_ok=True)
qwen_icd_retrieval_df.to_parquet('data/preds_retrieval_ds/qwen/icd_retrieval_df.parquet')
qwen_tr_and_icd_retrieval_df.to_parquet('data/preds_retrieval_ds/qwen/tr_and_icd_retrieval_df.parquet')