### Classification Performance: Cosine Similarity of Vector Embeddings vs. an LLM

##### Load Data

In [1]:
from datasets import load_dataset, concatenate_datasets

  from .autonotebook import tqdm as notebook_tqdm


In [12]:
# Fetch the ICD description/code mapping dataset, with all of its splits.
icd_dataset = load_dataset('krishnareddy/icddxdescmap', trust_remote_code = True)

# One pandas frame per split, in train / validation / test order.
icd_train, icd_validation, icd_test = (
    icd_dataset[split_name].to_pandas()
    for split_name in ('train', 'validation', 'test')
)

# Stack the three splits into a single dataset and convert it to pandas as well.
icd_concat = concatenate_datasets(
    [icd_dataset[split_name] for split_name in ('train', 'validation', 'test')]
)
icd_data = icd_concat.to_pandas()

icd_data.head()

Unnamed: 0,docdesc,dxcode,shortdesc,longdesc
0,12 week IUP,Z3A.12,12 weeks gestation of pregnancy,12 weeks gestation of pregnancy
1,14 weeks pregnant,Z3A.14,14 weeks gestation of pregnancy,14 weeks gestation of pregnancy
2,15 weeks pregnant,Z3A.15,15 weeks gestation of pregnancy,15 weeks gestation of pregnancy
3,17 wks pregnant,Z3A.17,17 weeks gestation of pregnancy,17 weeks gestation of pregnancy
4,2 weeks pregnant,Z3A.20,20 weeks gestation of pregnancy,20 weeks gestation of pregnancy


##### Clean

In [21]:
# Strip the '.' from every ICD code so the codes match the ICD code set that
# we have (e.g. 'Z3A.12' becomes 'Z3A12'). The same literal (non-regex)
# replacement is applied to the combined frame and to each split frame.
for frame in (icd_data, icd_train, icd_validation, icd_test):
    frame['dxcode'] = frame['dxcode'].str.replace('.', '', regex = False)

icd_data.head()

Unnamed: 0,docdesc,dxcode,shortdesc,longdesc
0,12 week IUP,Z3A12,12 weeks gestation of pregnancy,12 weeks gestation of pregnancy
1,14 weeks pregnant,Z3A14,14 weeks gestation of pregnancy,14 weeks gestation of pregnancy
2,15 weeks pregnant,Z3A15,15 weeks gestation of pregnancy,15 weeks gestation of pregnancy
3,17 wks pregnant,Z3A17,17 weeks gestation of pregnancy,17 weeks gestation of pregnancy
4,2 weeks pregnant,Z3A20,20 weeks gestation of pregnancy,20 weeks gestation of pregnancy


In [4]:
# Display the combined frame; the rendered table elides the middle rows.
icd_data

Unnamed: 0,docdesc,dxcode,shortdesc,longdesc
0,12 week IUP,Z3A12,12 weeks gestation of pregnancy,12 weeks gestation of pregnancy
1,14 weeks pregnant,Z3A14,14 weeks gestation of pregnancy,14 weeks gestation of pregnancy
2,15 weeks pregnant,Z3A15,15 weeks gestation of pregnancy,15 weeks gestation of pregnancy
3,17 wks pregnant,Z3A17,17 weeks gestation of pregnancy,17 weeks gestation of pregnancy
4,2 weeks pregnant,Z3A20,20 weeks gestation of pregnancy,20 weeks gestation of pregnancy
...,...,...,...,...
71541,VOMITING OR OTHER SEVERE,R1110,"Vomiting, unspecified","Vomiting, unspecified"
71542,"vomiting, vaginal",R1110,"Vomiting, unspecified","Vomiting, unspecified"
71543,"vulvovaginitis, pyelonephritis",N760,Acute vaginitis,Acute vaginitis
71544,warfarin hyperlipidemia,E785,"Hyperlipidemia, unspecified","Hyperlipidemia, unspecified"


##### Load ICD Embeddings

In [5]:
# Load the ICD embeddings through the project helper and keep them in the
# notebook-level `embeddings` variable, which `run_predictions` reads later.
# NOTE(review): 'pritamdeka' presumably selects the embeddings produced with
# the pritamdeka/S-PubMedBert-MS-MARCO model — confirm in helpers.icd.
from helpers.icd import open_icd_embeddings

embeddings = open_icd_embeddings('pritamdeka')

2025-05-12 19:13:46.338602: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-05-12 19:13:46.349191: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1747048426.363569  751831 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1747048426.367990  751831 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1747048426.379806  751831 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

##### Make Predictions

In [6]:
# Load the 'pritamdeka/S-PubMedBert-MS-MARCO' sentence-embedding checkpoint.
# NOTE(review): this `model` variable is not referenced by any later cell
# visible here; the same checkpoint is loaded again as `sentence_model`
# further down — consider keeping only one of the two loads.
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('pritamdeka/S-PubMedBert-MS-MARCO')

##### Embed Text from Dataset

In [46]:
# Remove data points with only 1 instance of 
# Partition the rows by how often their ICD code occurs in the dataset:
# rows whose code appears more than once go to `icd_data_trimmed`, rows whose
# code appears exactly once are kept aside in `icd_data_one_point`.
icd_counts = icd_data['dxcode'].value_counts()

repeated_codes = icd_counts[icd_counts > 1].index
singleton_codes = icd_counts[icd_counts == 1].index

icd_data_trimmed = icd_data[icd_data['dxcode'].isin(repeated_codes)]

icd_data_one_point = icd_data[icd_data['dxcode'].isin(singleton_codes)]

In [54]:
from sklearn.model_selection import train_test_split
# Hold out 10% of the rows with repeated codes as a test set, stratifying on
# the ICD code; the fixed random_state keeps the split repeatable.
stratify_labels = icd_data_trimmed['dxcode']

icd_train_stratified, icd_test_stratified = train_test_split(
    icd_data_trimmed,
    stratify = stratify_labels,
    test_size = 0.1,
    random_state = 2025,
)

In [15]:
# Load the sentence-embedding model that is passed to `run_predictions` below.
# NOTE(review): device = 'cuda' is hard-coded, so this cell presumably
# requires a CUDA-capable GPU — confirm before running on another machine.
from sentence_transformers import SentenceTransformer
sentence_model = SentenceTransformer('pritamdeka/S-PubMedBert-MS-MARCO', device = 'cuda')

In [56]:
from helpers.icd import get_top_k_similar
import pandas as pd
import torch

def run_predictions(data, model):
    """Predict an ICD code for every row of `data` by embedding similarity.

    For each row, the `docdesc` text is encoded with `model` and the encoding
    is handed, together with the notebook-level `embeddings`, to
    `get_top_k_similar`. The first code it returns is taken as the prediction.

    Parameters
    ----------
    data : pandas.DataFrame
        Rows to score; must provide `docdesc` (text) and `dxcode` (true code).
    model : object
        Anything exposing `encode(text)`, e.g. a SentenceTransformer.

    Returns
    -------
    pandas.DataFrame
        One row per input row with the columns `true_code`, `text`,
        `predicted_code`, `similarity_score`, `predicted` (1 if the first
        returned code equals the true code, else 0) and `in_top_5` (1 if the
        true code is among the returned codes, else 0). The columns are
        present even when `data` has no rows.
    """
    # Fix the output schema up front so an empty input still yields a frame
    # with the expected columns instead of a column-less DataFrame.
    columns = [
        'true_code',
        'text',
        'predicted_code',
        'similarity_score',
        'predicted',
        'in_top_5',
    ]
    results = []

    for row in data.itertuples():
        encoding = model.encode(row.docdesc)

        # NOTE(review): the 'in_top_5' name assumes the helper returns five
        # candidates, best match first, by default — confirm in helpers.icd.
        codes, sims = get_top_k_similar(encoding, embeddings)

        results.append({
            'true_code': row.dxcode,
            'text': row.docdesc,
            'predicted_code': codes[0],
            'similarity_score': sims[0],
            'predicted': int(row.dxcode == codes[0]),
            'in_top_5': int(row.dxcode in codes)
        })

    return pd.DataFrame(results, columns = columns)

##### Run Analysis

In [58]:
# Score the stratified hold-out rows (codes that occur more than once).
sentence_analysis_stratified = run_predictions(icd_test_stratified, sentence_model)

# Score the rows whose ICD code occurs only once in the dataset.
sentence_analysis_one_point = run_predictions(icd_data_one_point, sentence_model)

##### Save the Analysis

In [59]:
# Persist each prediction frame to its own CSV. The previous version referenced
# an undefined `sentence_analysis` name for both files; each result frame is
# now written to the file that matches it.
sentence_analysis_stratified.to_csv('./Saved_Data/icd_stratified.csv', index = False)
sentence_analysis_one_point.to_csv('./Saved_Data/icd_one_point.csv', index = False)

In [33]:
# Inspect the ICD code column of the combined frame.
icd_data['dxcode']

0          Z3A12
1          Z3A14
2          Z3A15
3          Z3A17
4          Z3A20
          ...   
71541      R1110
71542      R1110
71543       N760
71544       E785
71545    S383XXA
Name: dxcode, Length: 71546, dtype: object