In [1]:
import torch
import numpy as np
from scipy.stats import spearmanr
from sklearn.model_selection import train_test_split
from tqdm import tqdm

import pandas as pd
import pickle

from scripts.icd_tree import ICDTree
from scripts.data import create_batch_input_and_labels, get_squared_distances, calculate_spearman_correlations, evaluate_full_correlations, evaluate_correlation, locally_shuffle
from scripts.probe import TwoWordPSDProbeOriginal
from scripts.mc_search import train_weights_mc

In [8]:
# Load each dictionary
with open("base_BERT_dict.pkl", "rb") as file:
    base_BERT_dict = pickle.load(file)

with open("clinical_BERT_dict.pkl", "rb") as file:
    clinical_BERT_dict = pickle.load(file)

with open("random_BERT_dict.pkl", "rb") as file:
    random_BERT_dict = pickle.load(file)

# Initial Correlation Testing

In [72]:
file_path = "ICD10.csv"
df_tree = pd.read_csv(file_path, header=0)

df_codes = df_tree[df_tree['selectable'] == 'Y']
icd_codes = df_codes['coding'].tolist()

icd_tree = ICDTree(df_tree)

base_corr = evaluate_full_correlations(icd_tree, base_BERT_dict, icd_codes, batch_size = 4, window_size = None)
print(f'Base BERT model Spearman correlation: {base_corr}')

clinical_corr = evaluate_full_correlations(icd_tree, clinical_BERT_dict, icd_codes, batch_size = 4, window_size = None)
print(f'Clinical BERT model Spearman correlation: {clinical_corr}')

random_corr = evaluate_full_correlations(icd_tree, random_BERT_dict, icd_codes, batch_size = 4, window_size = None)
print(f'Random BERT model Spearman correlation: {random_corr}')

Evaluating batches: 100%|██████████| 4078/4078 [01:14<00:00, 54.64it/s]


Base BERT model Spearman correlation: 0.6446892730737822


Evaluating batches: 100%|██████████| 4078/4078 [01:18<00:00, 52.16it/s]


Clinical BERT model Spearman correlation: 0.6183534112640962


Evaluating batches: 100%|██████████| 4078/4078 [01:17<00:00, 52.37it/s]


Random BERT model Spearman correlation: 0.5901647156823182


# Depth Weight Adjustments

In [135]:
file_path = "ICD10.csv"
df_tree = pd.read_csv(file_path, header=0)

df_codes = df_tree[df_tree['selectable'] == 'Y']
icd_codes = df_codes['coding'].tolist()

icd_tree = ICDTree(df_tree)

icd_codes = locally_shuffle(icd_codes, 100)
icd_codes_train = icd_codes[:int(len(icd_codes)*0.8)]
icd_codes_test = icd_codes[int(len(icd_codes)*0.8):]

best_weights, train_correlations, test_correlations, max_diff = train_weights_mc(icd_tree, 
                                                                    base_BERT_dict, 
                                                                    icd_codes_train, 
                                                                    icd_codes_test,
                                                                    batch_size=10,
                                                                    max_epochs=1,
                                                                    search_radius=0.2,
                                                                    patience=25,
                                                                    initial_weights=[1, 1, 1, 1, 1],
                                                                    min_weight=0.05,
                                                                    num_samples=10,
                                                                    num_batches_to_average=20)

Benchmark Correlation on Test Set with Initial Weights [1. 1. 1. 1. 1.]:


Evaluating batches: 100%|██████████| 327/327 [00:47<00:00,  6.90it/s]


Test set correlation: 0.44423603047947635

-----------------------------------------------------------------
No Improvement Count --> 0

Epoch 1, Batch Group 1 to 20:
  Before update: weights = [1. 1. 1. 1. 1.], average correlation = 0.5064704414341519
Best correlation in group: 0.5066952911586535, Best weights: [0.81409428 0.89711459 1.07936857 0.86034168 0.95873829]
Largest difference in average correlation amongst weights groups: 0.0009851441748405243
------------ Batch performance did not improve --------------
Test correlation did not improve

-----------------------------------------------------------------
No Improvement Count --> 1

Epoch 1, Batch Group 21 to 40:
  Before update: weights = [1. 1. 1. 1. 1.], average correlation = 0.42594884715955084
Best correlation in group: 0.4279721488279117, Best weights: [0.93655336 0.91218762 0.831785   1.02795174 0.87211452]
Largest difference in average correlation amongst weights groups: 0.004627168631542833
Batch performance improved w

Evaluating batches: 100%|██████████| 327/327 [00:46<00:00,  7.00it/s]


Test set spearman correlation improved from 0.44423603047947635 to 0.4447759381186618

-----------------------------------------------------------------
No Improvement Count --> 0

Epoch 1, Batch Group 41 to 60:
  Before update: weights = [0.93655336 0.91218762 0.831785   1.02795174 0.87211452], average correlation = 0.4800988843808409
Best correlation in group: 0.48145043273437793, Best weights: [0.99892083 0.76917229 0.96234256 0.99935608 0.79410126]
Largest difference in average correlation amongst weights groups: 0.0013515483535370132
Batch performance improved with weights: [0.99892083 0.76917229 0.96234256 0.99935608 0.79410126]"
Changing weights...
Evaluating on the test set...


Evaluating batches: 100%|██████████| 327/327 [00:46<00:00,  7.03it/s]


Test correlation did not improve

-----------------------------------------------------------------
No Improvement Count --> 1

Epoch 1, Batch Group 61 to 80:
  Before update: weights = [0.99892083 0.76917229 0.96234256 0.99935608 0.79410126], average correlation = 0.4824464610694796
Best correlation in group: 0.4850166471559416, Best weights: [1.12534118 0.85344232 0.88999931 1.17562491 0.7169454 ]
Largest difference in average correlation amongst weights groups: 0.002570186086462034
Batch performance improved with weights: [1.12534118 0.85344232 0.88999931 1.17562491 0.7169454 ]"
Changing weights...
Evaluating on the test set...


Evaluating batches: 100%|██████████| 327/327 [00:46<00:00,  7.10it/s]


Test set spearman correlation improved from 0.4447759381186618 to 0.4453728962085087

-----------------------------------------------------------------
No Improvement Count --> 0

Epoch 1, Batch Group 81 to 100:
  Before update: weights = [1.12534118 0.85344232 0.88999931 1.17562491 0.7169454 ], average correlation = 0.47663811018133434
Best correlation in group: 0.47663811018133434, Best weights: [1.23669347 0.78244739 0.74632947 1.11616384 0.74976216]
Largest difference in average correlation amongst weights groups: 0.0
------------ Batch performance did not improve --------------
Test correlation did not improve

-----------------------------------------------------------------
No Improvement Count --> 1

Epoch 1, Batch Group 101 to 120:
  Before update: weights = [1.12534118 0.85344232 0.88999931 1.17562491 0.7169454 ], average correlation = 0.4354983618700781
Best correlation in group: 0.4354983618700781, Best weights: [0.95601634 0.71975292 0.80660976 1.20251364 0.59808663]
Large

Evaluating batches: 100%|██████████| 327/327 [00:46<00:00,  7.06it/s]


Test correlation did not improve

-----------------------------------------------------------------
No Improvement Count --> 4

Epoch 1, Batch Group 161 to 180:
  Before update: weights = [0.97250136 0.87772913 1.02915484 1.08754474 0.88636447], average correlation = 0.3852124540207381
Best correlation in group: 0.389156870978272, Best weights: [0.93415143 1.07409878 0.87710433 1.07038336 0.71817548]
Largest difference in average correlation amongst weights groups: 0.005931121226384672
Batch performance improved with weights: [0.93415143 1.07409878 0.87710433 1.07038336 0.71817548]"
Changing weights...
Evaluating on the test set...


Evaluating batches: 100%|██████████| 327/327 [00:47<00:00,  6.90it/s]


Test correlation did not improve

-----------------------------------------------------------------
No Improvement Count --> 5

Epoch 1, Batch Group 181 to 200:
  Before update: weights = [0.93415143 1.07409878 0.87710433 1.07038336 0.71817548], average correlation = 0.386278572525275
Best correlation in group: 0.386278572525275, Best weights: [1.10998786 1.24705996 0.89338376 1.04750454 0.65866028]
Largest difference in average correlation amongst weights groups: 0.0
------------ Batch performance did not improve --------------
Test correlation did not improve

-----------------------------------------------------------------
No Improvement Count --> 6

Epoch 1, Batch Group 201 to 220:
  Before update: weights = [0.93415143 1.07409878 0.87710433 1.07038336 0.71817548], average correlation = 0.4495400827470628
Best correlation in group: 0.44958198038164604, Best weights: [0.88368433 1.03924905 0.98558076 1.17949447 0.60472918]
Largest difference in average correlation amongst weights g

Evaluating batches: 100%|██████████| 327/327 [00:46<00:00,  6.97it/s]


Test correlation did not improve

-----------------------------------------------------------------
No Improvement Count --> 18

Epoch 1, Batch Group 441 to 460:
  Before update: weights = [1.12479487 0.99802966 0.92941789 0.96376214 0.53176536], average correlation = 0.5292529708032654
Best correlation in group: 0.5292529708032654, Best weights: [1.08040297 0.97487261 0.90059658 1.07912215 0.63323877]
Largest difference in average correlation amongst weights groups: 0.0
------------ Batch performance did not improve --------------
Test correlation did not improve

-----------------------------------------------------------------
No Improvement Count --> 19

Epoch 1, Batch Group 461 to 480:
  Before update: weights = [1.12479487 0.99802966 0.92941789 0.96376214 0.53176536], average correlation = 0.4254463004547041
Best correlation in group: 0.4254463004547041, Best weights: [1.2002789  1.15846577 1.09487031 1.14664031 0.63489794]
Largest difference in average correlation amongst weight

# Now compare the correlation of the optimized weights with normal weights on the test set:

In [136]:
test_correlations

[0.44423603047947635,
 0.44423603047947635,
 0.4447759381186618,
 0.44379902980726527,
 0.4453728962085087,
 0.4453728962085087,
 0.4453728962085087,
 0.4453728962085087,
 0.44379902980726527,
 0.44498680654028466,
 0.4453728962085087,
 0.4453728962085087,
 0.4453728962085087,
 0.4453728962085087,
 0.4453728962085087,
 0.4453728962085087,
 0.4453728962085087,
 0.4453728962085087,
 0.4453728962085087,
 0.4453728962085087,
 0.4453728962085087,
 0.4453728962085087,
 0.44498680654028466,
 0.4453728962085087,
 0.4453728962085087,
 0.4453728962085087,
 0.4453728962085087,
 0.4453728962085087,
 0.4453728962085087,
 0.4453728962085087,
 0.4453728962085087]