### load ICD10

In [10]:
import networkx as nx
from tqdm import tqdm
from ICD10 import ICD10

In [11]:
import os

# ICD-10 order file to load (presumably the CMS ICD-10-CM 2023 order file,
# judging by its name). Set the ICD10_PATH environment variable to point at
# your own copy instead of editing this machine-specific absolute path.
ICD10_PATH = os.environ.get(
    "ICD10_PATH",
    "/home/amfierens/Documents/Researcher/ICD10/icd10OrderFiles/icd10cm_order_2023.txt",
)
icd10 = ICD10(ICD10_PATH)
icd10.load_icd10()

In [12]:
# Build one (description, code) pair per description attached to each node
# of the ICD-10 graph, in graph node order. A code with several descriptions
# contributes several pairs.
icd10_sf_id_pairs = [
    (desc, code)
    for code in tqdm(icd10.graph.nodes)
    for desc in icd10.index_definition[code]
]

print(len(icd10_sf_id_pairs))

100%|██████████| 96795/96795 [00:00<00:00, 1172480.12it/s]

96795





In [13]:
# Peek at the first ten (description, code) pairs.
icd10_sf_id_pairs[0:10]

[('Cholera', 'A00'),
 ('Cholera due to Vibrio cholerae 01, biovar cholerae', 'A000'),
 ('Cholera due to Vibrio cholerae 01, biovar eltor', 'A001'),
 ('Cholera, unspecified', 'A009'),
 ('Typhoid and paratyphoid fevers', 'A01'),
 ('Typhoid fever', 'A010'),
 ('Typhoid fever, unspecified', 'A0100'),
 ('Typhoid meningitis', 'A0101'),
 ('Typhoid fever with heart involvement', 'A0102'),
 ('Typhoid pneumonia', 'A0103')]

In [14]:
# Cap how many (description, code) pairs get encoded, to keep SapBERT encoding
# time manageable (50k names took ~21 min in the recorded run). Only the first
# N_ICD10_PAIRS pairs in graph order become searchable; codes after the cut can
# never be predicted. Set to None to index every code.
N_ICD10_PAIRS = 50_000
icd10_sf_id_pairs = icd10_sf_id_pairs[:N_ICD10_PAIRS]

In [15]:
# Split the pairs into two parallel lists: the names to encode and their codes.
all_names = [name for name, _code in icd10_sf_id_pairs]
all_ids = [code for _name, code in icd10_sf_id_pairs]

In [16]:
# Peek at the first ten names that will be encoded.
all_names[0:10]

['Cholera',
 'Cholera due to Vibrio cholerae 01, biovar cholerae',
 'Cholera due to Vibrio cholerae 01, biovar eltor',
 'Cholera, unspecified',
 'Typhoid and paratyphoid fevers',
 'Typhoid fever',
 'Typhoid fever, unspecified',
 'Typhoid meningitis',
 'Typhoid fever with heart involvement',
 'Typhoid pneumonia']

In [17]:
# Sanity check: graph predecessors of the second code, then the first ten codes.
print(icd10.predecessors(all_ids[1]))
all_ids[0:10]


['A00']


['A00',
 'A000',
 'A001',
 'A009',
 'A01',
 'A010',
 'A0100',
 'A0101',
 'A0102',
 'A0103']

### load SapBERT

In [18]:
!pip install transformers
!pip install torch



In [19]:
from transformers import AutoTokenizer, AutoModel

# The same SapBERT checkpoint supplies both the tokenizer and the encoder.
SAPBERT_CHECKPOINT = "cambridgeltl/SapBERT-from-PubMedBERT-fulltext"
tokenizer = AutoTokenizer.from_pretrained(SAPBERT_CHECKPOINT)
# Model stays on CPU; move it with .cuda(...) / .to(device) to use a GPU.
model = AutoModel.from_pretrained(SAPBERT_CHECKPOINT)

In [20]:
import numpy as np
import torch

### encode ICD10 labels

In [21]:
# Encode every ICD-10 name with SapBERT in batches of `bs`, taking the
# embedding at token position 0 ([CLS]) as the name's representation.
# Names are padded/truncated to 25 tokens; the query cell uses the same settings.
bs = 32
all_reps = []
# Inference only: no_grad skips building autograd graphs for every batch,
# which saves memory and time (previously each output was detach()ed anyway).
with torch.no_grad():
    for i in tqdm(range(0, len(all_names), bs)):
        toks = tokenizer.batch_encode_plus(all_names[i:i+bs],
                                           padding="max_length",
                                           max_length=25,
                                           truncation=True,
                                           return_tensors="pt")
        output = model(**toks)
        cls_rep = output[0][:, 0, :]  # first token of every sequence in the batch

        all_reps.append(cls_rep.cpu().numpy())
all_reps_emb = np.concatenate(all_reps, axis=0)

  0%|          | 0/1563 [00:00<?, ?it/s]

100%|██████████| 1563/1563 [21:31<00:00,  1.21it/s]


In [22]:
print (all_reps_emb.shape)

(50000, 768)


### encode query

In [32]:
query = "arthritis"
# Tokenize the query with the same padding/truncation settings used for the
# ICD-10 names, so both go through the encoder identically.
query_toks = tokenizer.batch_encode_plus(
    [query],
    padding="max_length",
    max_length=25,
    truncation=True,
    return_tensors="pt",
)

In [33]:
# Encode the query without autograd tracking (inference only) and keep the
# embedding at token position 0 ([CLS]), matching how the ICD-10 names were encoded.
with torch.no_grad():
    query_output = model(**query_toks)
query_cls_rep = query_output[0][:, 0, :]

In [34]:
# Expect one row (one query) with the encoder's hidden width.
query_cls_rep.size()

torch.Size([1, 768])

### find the query's nearest neighbours

In [35]:
# for large-scale search, should switch to faiss
from scipy.spatial.distance import cdist

In [36]:
# Distance (cdist's default Euclidean metric) from the query embedding to every
# ICD-10 name embedding; the closest name is the predicted label.
dist = cdist(
    query_cls_rep.cpu().detach().numpy(),
    all_reps_emb,
)
nn_index = dist.argmin()
print("predicted label:", icd10_sf_id_pairs[nn_index])

predicted label: ('Arthropathy, unspecified', 'M129')


In [38]:
# List the TOP_K ICD-10 names closest to the query (closest first), together
# with each hit's graph predecessors and their ICD-10 entries.
TOP_K = 10
indexes = np.argsort(dist[0])[:TOP_K]

for i, index in enumerate(indexes):
    desc, code = icd10_sf_id_pairs[index]
    print("====================")
    print(f"Predicted label {i}:", (desc, code))
    for k, pred in enumerate(icd10.predecessors(code)):
        print(f"\tpredecessor {k}:", pred, icd10[pred])

Predicted label 0: ('Arthropathy, unspecified', 'M129')
	predecessor 0: M12 {'desc': 'Other and unspecified arthropathy'}
Predicted label 1: ('Rheumatism, unspecified', 'M790')
	predecessor 0: M79 {'desc': 'Oth and unsp soft tissue disorders, not elsewhere classified'}
Predicted label 2: ('Osteoarthritis, unspecified site', 'M199')
	predecessor 0: M19 {'desc': 'Other and unspecified osteoarthritis'}
Predicted label 3: ('Other arthritis', 'M13')
Predicted label 4: ('Pain in joint', 'M255')
	predecessor 0: M25 {'desc': 'Other joint disorder, not elsewhere classified'}
Predicted label 5: ('Pain in unspecified joint', 'M2550')
	predecessor 0: M255 {'desc': 'Pain in joint'}
Predicted label 6: ('Unspecified osteoarthritis, unspecified site', 'M1990')
	predecessor 0: M199 {'desc': 'Osteoarthritis, unspecified site'}
Predicted label 7: ('Joint disorder, unspecified', 'M259')
	predecessor 0: M25 {'desc': 'Other joint disorder, not elsewhere classified'}
Predicted label 8: ('Reactive arthropathy

In [29]:
# Spot check: predecessors of a leaf code, then the entry for code J4599.
print(icd10.predecessors("J45998"))
icd10["J4599"]

['J4599']


{'desc': 'Other asthma'}