In [18]:
import numpy as np
import pandas as pd
import torch

import pickle
from sklearn.model_selection import train_test_split


from bert_wrapper import BERTWrapper
from scripts.probe import TwoWordPSDProbe
from scripts.probe_regimen import ProbeRegimen
from scripts.loss import L1DistanceLoss


Restricting the analysis to the selectable ICD-10 codes, meaning those that would actually be used in practice. The non-selectable codes represent broader categories

In [27]:
df = pd.read_csv('ICD10.csv')


df_codes = df[df['selectable'] == 'Y']
icd_codes = df_codes['coding'].tolist()
df_codes

Unnamed: 0,coding,meaning,node_id,parent_id,selectable
4,A000,"A00.0 Cholera due to Vibrio cholerae 01, biova...",287,286.0,Y
5,A001,"A00.1 Cholera due to Vibrio cholerae 01, biova...",288,286.0,Y
6,A009,"A00.9 Cholera, unspecified",289,286.0,Y
8,A010,A01.0 Typhoid fever,291,290.0,Y
9,A011,A01.1 Paratyphoid fever A,292,290.0,Y
...,...,...,...,...,...
19150,Z992,Z99.2 Dependence on renal dialysis,19150,19147.0,Y
19151,Z993,Z99.3 Dependence on wheelchair,19151,19147.0,Y
19152,Z994,Z99.4 Dependence on artificial heart,19152,19147.0,Y
19153,Z998,Z99.8 Dependence on other enabling machines an...,19153,19147.0,Y


In [21]:
df_codes.loc[:, 'text'] = df_codes.apply(
    lambda row: row['meaning'][row['meaning'][:10].rfind(row['coding'][-1])+2:] 
    if row['coding'][-1] in row['meaning'][:10] else row['meaning'], 
    axis=1
)


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_codes.loc[:, 'text'] = df_codes.apply(


In [22]:
df_codes = df_codes[['coding', 'text', 'node_id', 'parent_id']]
df_codes

Unnamed: 0,coding,text,node_id,parent_id
4,A000,"Cholera due to Vibrio cholerae 01, biovar chol...",287,286.0
5,A001,"Cholera due to Vibrio cholerae 01, biovar el tor",288,286.0
6,A009,"Cholera, unspecified",289,286.0
8,A010,Typhoid fever,291,290.0
9,A011,Paratyphoid fever A,292,290.0
...,...,...,...,...
19150,Z992,Dependence on renal dialysis,19150,19147.0
19151,Z993,Dependence on wheelchair,19151,19147.0
19152,Z994,Dependence on artificial heart,19152,19147.0
19153,Z998,Dependence on other enabling machines and devices,19153,19147.0


### Text embedding of the ICD code names (*SKIP IF YOU'VE ALREADY DONE THIS*)

In [23]:
base_BERT = BERTWrapper(model_type="bert-base-uncased", random_init=False)
clinical_BERT = BERTWrapper(model_type="emilyalsentzer/Bio_ClinicalBERT", random_init=False)
random_BERT = BERTWrapper(random_init=True)

base_BERT_embeddings = base_BERT.embed_text_array(df_codes['text'].tolist())
clinical_BERT_embeddings = clinical_BERT.embed_text_array(df_codes['text'].tolist())
random_BERT_embeddings = random_BERT.embed_text_array(df_codes['text'].tolist())

base_BERT_dict = {code: base_BERT_embeddings[i] for i, code in enumerate(icd_codes)}
clinical_BERT_dict = {code: clinical_BERT_embeddings[i] for i, code in enumerate(icd_codes)}
random_BERT_dict = {code: random_BERT_embeddings[i] for i, code in enumerate(icd_codes)}

# Save each dictionary to its own file
with open("base_BERT_dict.pkl", "wb") as file:
    pickle.dump(base_BERT_dict, file)

with open("clinical_BERT_dict.pkl", "wb") as file:
    pickle.dump(clinical_BERT_dict, file)

with open("random_BERT_dict.pkl", "wb") as file:
    pickle.dump(random_BERT_dict, file)



In [35]:
# Load each dictionary
with open("base_BERT_dict.pkl", "rb") as file:
    base_BERT_dict = pickle.load(file)

with open("clinical_BERT_dict.pkl", "rb") as file:
    clinical_BERT_dict = pickle.load(file)

with open("random_BERT_dict.pkl", "rb") as file:
    random_BERT_dict = pickle.load(file)

In [83]:
file_path = "ICD10.csv"
df_tree = pd.read_csv(file_path, header=0)
icd_tree = ICDTree(df_tree)


df_codes = df_tree[df_tree['selectable'] == 'Y']
icd_codes = df_codes['coding'].tolist()

icd_tree = ICDTree(df_tree)

icd_codes = locally_shuffle(icd_codes, 100)

# Create held out test set 
icd_codes_train = icd_codes[:int(len(icd_codes)*0.8)]
icd_codes_test = icd_codes[int(len(icd_codes)*0.8):]

# From the remaining set, create the train and dev sets
icd_codes_train = icd_codes_train[:int(len(icd_codes_train)*0.8)]
icd_codes_dev = icd_codes_train[int(len(icd_codes_train)*0.8):]

probe_rank = 64

# Probe Training

Initialize ICD Tree:

In [84]:
probe1 = TwoWordPSDProbe(base_BERT_dict["A000"].shape[0], 64)

loss1 = L1DistanceLoss()

probe_regimen1 = ProbeRegimen(icd_tree)

probe_regimen1.train(probe1, loss1, icd_codes_train, icd_codes_dev, base_BERT_dict, batch_size=10, name="base")

Constructing TwoWordPSDProbe for a single large tree


[epoch 0] Training:   3%|▎         | 280/10438 [00:04<02:39, 63.82samples/s]


KeyboardInterrupt: 

In [82]:
probe_loaded1 = TwoWordPSDProbe(base_BERT_dict["A000"].shape[0], 64)  # Replace ProbeModel with your actual model class
probe_loaded1.load_state_dict(torch.load('saved_models/base_batch_size_10_epoch_9.pt'))
probe_loaded1.eval()  # Set the model to evaluation mode

probe_regimen = ProbeRegimen(icd_tree)

# Generate predictions
predictions, labels = probe_regimen1.predict(probe1, icd_codes_test, base_BERT_dict, batch_size=10)

correlations, mean_correlations = calculate_spearman_correlations(predictions, labels)
print(f"Mean Spearman correlation: {mean_correlations}")

Constructing TwoWordPSDProbe for a single large tree


[predicting batches]:   0%|          | 0/327 [00:00<?, ?it/s]

[predicting batches]: 100%|██████████| 327/327 [00:48<00:00,  6.73it/s]


Mean Spearman correlation: 0.4585213634205754


In [80]:
mean_correlations = evaluate_full_correlations(icd_tree, base_BERT_dict, icd_codes_test, batch_size=10, shuffle_icd=False)
print(f"Mean Spearman correlation: {mean_correlations}")

Evaluating batches: 100%|██████████| 327/327 [00:50<00:00,  6.53it/s]


Mean Spearman correlation: 0.498880413276541


# Clinical BERT

In [91]:
probe2 = TwoWordPSDProbe(clinical_BERT_dict["A000"].shape[0], 64)

loss2 = L1DistanceLoss()

probe_regimen2 = ProbeRegimen(icd_tree)

probe_regimen2.train(probe2, loss2, icd_codes_train, icd_codes_dev, clinical_BERT_dict, batch_size=10, name="clinical")

Constructing TwoWordPSDProbe for a single large tree


[epoch 0] Training: 100%|██████████| 10438/10438 [02:39<00:00, 65.46samples/s]
[epoch 0] Validation: 100%|██████████| 2088/2088 [00:29<00:00, 71.77samples/s]


[epoch 0] Train loss: 1.520573702009245, Dev loss: 1.2910674453922435
[epoch 0] New best model saved at: saved_models/clinical_batch_size_10_epoch_0.pt


[epoch 1] Training: 100%|██████████| 10438/10438 [02:31<00:00, 68.79samples/s]
[epoch 1] Validation: 100%|██████████| 2088/2088 [00:31<00:00, 67.33samples/s]


[epoch 1] Train loss: 1.2414258820686304, Dev loss: 1.1819702227149853
[epoch 1] New best model saved at: saved_models/clinical_batch_size_10_epoch_1.pt


[epoch 2] Training: 100%|██████████| 10438/10438 [02:36<00:00, 66.49samples/s]
[epoch 2] Validation: 100%|██████████| 2088/2088 [00:29<00:00, 71.93samples/s]


[epoch 2] Train loss: 1.150104395731199, Dev loss: 1.1113553891341652
[epoch 2] New best model saved at: saved_models/clinical_batch_size_10_epoch_2.pt


[epoch 3] Training: 100%|██████████| 10438/10438 [02:33<00:00, 67.98samples/s]
[epoch 3] Validation: 100%|██████████| 2088/2088 [00:29<00:00, 71.23samples/s]


[epoch 3] Train loss: 1.0763883380026653, Dev loss: 1.0532372186058445
[epoch 3] New best model saved at: saved_models/clinical_batch_size_10_epoch_3.pt


[epoch 4] Training: 100%|██████████| 10438/10438 [02:34<00:00, 67.72samples/s]
[epoch 4] Validation: 100%|██████████| 2088/2088 [00:29<00:00, 70.86samples/s]


[epoch 4] Train loss: 1.0453882981305835, Dev loss: 0.9980160285981649
[epoch 4] New best model saved at: saved_models/clinical_batch_size_10_epoch_4.pt


[epoch 5] Training: 100%|██████████| 10438/10438 [02:32<00:00, 68.61samples/s]
[epoch 5] Validation: 100%|██████████| 2088/2088 [00:30<00:00, 68.81samples/s]


[epoch 5] Train loss: 1.013547432690973, Dev loss: 0.9796607277610085
[epoch 5] New best model saved at: saved_models/clinical_batch_size_10_epoch_5.pt


[epoch 6] Training: 100%|██████████| 10438/10438 [02:40<00:00, 65.06samples/s]
[epoch 6] Validation: 100%|██████████| 2088/2088 [00:29<00:00, 71.64samples/s]


[epoch 6] Train loss: 0.9859852891315446, Dev loss: 0.9902253205125983


[epoch 7] Training: 100%|██████████| 10438/10438 [02:32<00:00, 68.36samples/s]
[epoch 7] Validation: 100%|██████████| 2088/2088 [00:29<00:00, 71.94samples/s]


[epoch 7] Train loss: 0.8453703537526258, Dev loss: 0.7689726697771173
[epoch 7] New best model saved at: saved_models/clinical_batch_size_10_epoch_7.pt


[epoch 8] Training: 100%|██████████| 10438/10438 [02:33<00:00, 67.90samples/s]
[epoch 8] Validation: 100%|██████████| 2088/2088 [00:29<00:00, 71.00samples/s]


[epoch 8] Train loss: 0.7455706996593439, Dev loss: 0.7074550713933826
[epoch 8] New best model saved at: saved_models/clinical_batch_size_10_epoch_8.pt


[epoch 9] Training: 100%|██████████| 10438/10438 [02:37<00:00, 66.28samples/s]
[epoch 9] Validation: 100%|██████████| 2088/2088 [00:29<00:00, 71.51samples/s]

[epoch 9] Train loss: 0.7054539959149799, Dev loss: 0.6739295970595054
[epoch 9] New best model saved at: saved_models/clinical_batch_size_10_epoch_9.pt





In [92]:
probe_loaded2 = TwoWordPSDProbe(clinical_BERT_dict["A000"].shape[0], 64)  # Replace ProbeModel with your actual model class
probe_loaded2.load_state_dict(torch.load('saved_models/clinical_batch_size_10_epoch_9.pt'))
probe_loaded2.eval()  # Set the model to evaluation mode

probe_regimen2 = ProbeRegimen(icd_tree)

# Generate predictions
predictions2, labels2 = probe_regimen2.predict(probe_loaded2, icd_codes_test, clinical_BERT_dict, batch_size=10)

correlations2, mean_correlations2 = calculate_spearman_correlations(predictions2, labels2)
print(f"Mean Spearman correlation: {mean_correlations2}")

Constructing TwoWordPSDProbe for a single large tree


[predicting batches]: 100%|██████████| 327/327 [00:46<00:00,  7.04it/s]

Mean Spearman correlation: 0.37179386244581386





In [93]:
mean_correlations = evaluate_full_correlations(icd_tree, clinical_BERT_dict, icd_codes_test, batch_size=10, shuffle_icd=False)
print(f"Mean Spearman correlation: {mean_correlations}")

Evaluating batches: 100%|██████████| 327/327 [00:49<00:00,  6.56it/s]

Mean Spearman correlation: 0.37709607757004865





# Random BERT

In [87]:
probe3 = TwoWordPSDProbe(random_BERT_dict["A000"].shape[0], 64)

loss3 = L1DistanceLoss()

probe_regimen3 = ProbeRegimen(icd_tree)

probe_regimen3.train(probe3, loss3, icd_codes_train, icd_codes_dev, random_BERT_dict, batch_size=10, name="random")

Constructing TwoWordPSDProbe for a single large tree


[epoch 0] Training: 100%|██████████| 10438/10438 [02:36<00:00, 66.59samples/s]
[epoch 0] Validation: 100%|██████████| 2088/2088 [00:29<00:00, 71.22samples/s]


[epoch 0] Train loss: 1.6444519889537401, Dev loss: 1.4210989093096063
[epoch 0] New best model saved at: saved_models/random_batch_size_10_epoch_0.pt


[epoch 1] Training: 100%|██████████| 10438/10438 [02:32<00:00, 68.59samples/s]
[epoch 1] Validation: 100%|██████████| 2088/2088 [00:29<00:00, 70.21samples/s]


[epoch 1] Train loss: 1.4104640818989596, Dev loss: 1.3088650743356731
[epoch 1] New best model saved at: saved_models/random_batch_size_10_epoch_1.pt


[epoch 2] Training: 100%|██████████| 10438/10438 [02:32<00:00, 68.38samples/s]
[epoch 2] Validation: 100%|██████████| 2088/2088 [00:28<00:00, 72.87samples/s]


[epoch 2] Train loss: 1.3340009189428497, Dev loss: 1.2389429191653238
[epoch 2] New best model saved at: saved_models/random_batch_size_10_epoch_2.pt


[epoch 3] Training: 100%|██████████| 10438/10438 [02:32<00:00, 68.60samples/s]
[epoch 3] Validation: 100%|██████████| 2088/2088 [00:29<00:00, 71.02samples/s]


[epoch 3] Train loss: 1.2781330351628564, Dev loss: 1.2131268516111602
[epoch 3] New best model saved at: saved_models/random_batch_size_10_epoch_3.pt


[epoch 4] Training: 100%|██████████| 10438/10438 [02:31<00:00, 68.91samples/s]
[epoch 4] Validation: 100%|██████████| 2088/2088 [00:29<00:00, 71.77samples/s]


[epoch 4] Train loss: 1.2419221013841502, Dev loss: 1.1866808565039384
[epoch 4] New best model saved at: saved_models/random_batch_size_10_epoch_4.pt


[epoch 5] Training: 100%|██████████| 10438/10438 [02:33<00:00, 68.00samples/s]
[epoch 5] Validation: 100%|██████████| 2088/2088 [00:29<00:00, 71.69samples/s]


[epoch 5] Train loss: 1.2182571171458196, Dev loss: 1.1483177973322891
[epoch 5] New best model saved at: saved_models/random_batch_size_10_epoch_5.pt


[epoch 6] Training: 100%|██████████| 10438/10438 [02:31<00:00, 69.04samples/s]
[epoch 6] Validation: 100%|██████████| 2088/2088 [00:29<00:00, 70.80samples/s]


[epoch 6] Train loss: 1.2003826585766, Dev loss: 1.1089944497249913
[epoch 6] New best model saved at: saved_models/random_batch_size_10_epoch_6.pt


[epoch 7] Training:  12%|█▏        | 1260/10438 [00:19<02:19, 65.77samples/s]


KeyboardInterrupt: 

In [89]:
probe_loaded3 = TwoWordPSDProbe(random_BERT_dict["A000"].shape[0], 64)  # Replace ProbeModel with your actual model class
probe_loaded3.load_state_dict(torch.load('saved_models/random_batch_size_10_epoch_6.pt'))
probe_loaded3.eval()  # Set the model to evaluation mode

probe_regimen3 = ProbeRegimen(icd_tree)

# Generate predictions
predictions3, labels3 = probe_regimen3.predict(probe_loaded3, icd_codes_test, random_BERT_dict, batch_size=10)

correlations3, mean_correlations3 = calculate_spearman_correlations(predictions3, labels3)
print(f"Mean Spearman correlation: {mean_correlations3}")

Constructing TwoWordPSDProbe for a single large tree


[predicting batches]: 100%|██████████| 327/327 [00:45<00:00,  7.18it/s]

Mean Spearman correlation: 0.2709406532230638





In [90]:
mean_correlations = evaluate_full_correlations(icd_tree, random_BERT_dict, icd_codes_test, batch_size=10, shuffle_icd=False)
print(f"Mean Spearman correlation: {mean_correlations}")

Evaluating batches: 100%|██████████| 327/327 [00:47<00:00,  6.87it/s]


Mean Spearman correlation: 0.296835555182813
