In [None]:
!pip install gspread --upgrade -q

In [None]:
import pandas as pd
import os
from collections import defaultdict

pd.set_option('display.max_colwidth', None)

In [None]:
# Mount Google Drive at /content/gdrive; the data paths configured in this notebook live under that mount point.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
# Root folder of the project on the mounted shared drive; every other location is derived from it.
DATA_HOME = '/content/gdrive/Shareddrives/PROJECT_ROOT_DIR'

# Folder holding the OHDSI vocabulary files (the ICD-10 concept table is read from here).
OHDSI_VOCAB_HOME = os.path.join(DATA_HOME, 'ohdsi-vocab')

# Folder that receives the dataset files written by this notebook.
DS_HOME = os.path.join(DATA_HOME, 'injury-icd-dataset')

# Optional: pretrain BERT on 5-character ICD-10 codes, then fine-tune on 4-characters
DS_HOME_PRETRAIN = os.path.join(DATA_HOME, 'pretrain', 'injury-icd-dataset')

# Folder where save_plot() stores the figures.
PLOT_HOME = os.path.join(DATA_HOME, 'publishing', 'figures')

In [None]:
import gspread
# NOTE(review): GoogleCredentials is not referenced anywhere in the visible notebook.
from oauth2client.client import GoogleCredentials
from google.colab import auth
from google.auth import default
# Authenticate the Colab user and build a gspread client from the default Google credentials.
auth.authenticate_user()
creds, _ = default()
gc = gspread.authorize(creds)
# A URL to the spreadsheet with patient data. Must include columns: ['patient_id', 'tertiary_imaging_report', 'tertiary_impression', 'ICD 10']
# NOTE(review): the loading cell renames 'Tertiary_imaging_report' / 'Tertiary_impression' (capitalised) — confirm the exact header spelling.
# See example here: https://docs.google.com/spreadsheets/d/19PKbWvzFohSQhzaMaz9lvfDuOqMZI8ZJM7aqzZ57Xeg/edit?usp=sharing
# The value is passed to open_by_url() unchanged, so it must be filled in before this cell is run.
SPRSHEET_URL = ""
sprsheet = gc.open_by_url(SPRSHEET_URL)
# Patient rows are read from the worksheet named 'data'.
ws = sprsheet.worksheet('data')

In [None]:
# Read every record of the worksheet into a DataFrame, map the spreadsheet headers to the
# snake_case column names used by the rest of the notebook, and store patient ids as strings.
# NOTE(review): the 'AIS  Code' key below contains two spaces, while a later cell still selects
# 'AIS Code' (one space) — confirm the exact header used in the spreadsheet.
raw_data = (
    pd.DataFrame(ws.get_all_records())
    .rename(columns={
        'ICD 10': 'icd_code',
        'ICD10 Description': 'icd_name',
        'Diagnoses Region': 'diagnoses_region',
        'AIS  Code': 'ais_code',
        'AIS dESC': 'ais_name',
        'Tertiary_exam': 'tertiary_exam',
        'Tertiary_imaging_report': 'tertiary_imaging_report',
        'Tertiary_impression': 'tertiary_impression'
    })
    .assign(patient_id=lambda frame: frame.patient_id.astype(str))
)
# assert raw_data[raw_data.tertiary_exam != ''].shape == raw_data[raw_data.tertiary_impression != ''].shape
# assert raw_data[raw_data.tertiary_exam != ''].shape == raw_data[raw_data.tertiary_imaging_report != ''].shape

In [None]:
# Shape of the rows that have tertiary exam text but an empty impression.
raw_data[(raw_data.tertiary_impression == '')&(raw_data.tertiary_exam != '')].shape

In [None]:
# Shape of the rows that have an impression but empty tertiary exam text.
raw_data[(raw_data.tertiary_impression != '')&(raw_data.tertiary_exam == '')].shape

In [None]:
# A case is any row with a non-empty impression or a non-empty exam; only the patient id
# and the three free-text columns are kept, as an independent copy of raw_data.
has_survey_text = (raw_data.tertiary_impression != '') | (raw_data.tertiary_exam != '')
cases = raw_data.loc[
    has_survey_text,
    ['patient_id', 'tertiary_exam', 'tertiary_imaging_report', 'tertiary_impression'],
].copy()
#assert not any(cases.duplicated('patient_id'))

In [None]:
# Total number of characters over the three free-text fields of each case.
# Each value is converted with str first (as the row-wise version did), then the lengths are
# summed column by column. This avoids a Python-level call per row and also works on an
# empty frame, where DataFrame.apply(axis=1) returns a DataFrame rather than a Series.
text_columns = ['tertiary_exam', 'tertiary_imaging_report', 'tertiary_impression']
cases['total_text_len'] = sum(cases[column].astype(str).str.len() for column in text_columns)

In [None]:
# Shape of cases before patients with duplicated ids are removed.
cases.shape

In [None]:
# Patient ids that appear on more than one row of cases, sorted, each listed once.
is_repeated = cases.duplicated('patient_id', keep=False)
duplicated_patient_ids = sorted(cases.loc[is_repeated, 'patient_id'].unique())

print('Duplicated patient_ids', duplicated_patient_ids)

In [None]:
# Number of distinct patient ids that occur more than once.
len(duplicated_patient_ids)

In [None]:
# Remove every row of a duplicated patient (no occurrence is kept, not even the first).
cases = cases[~cases.patient_id.isin(duplicated_patient_ids)].copy()

In [None]:
# Shape of cases after the duplicated patients were removed.
cases.shape

In [None]:
# Cases whose impression has fewer than two characters.
# Values are converted with str before measuring, consistent with how total_text_len is
# computed; a bare apply(len) raises TypeError on any cell that is not a string.
cases[cases.tertiary_impression.astype(str).str.len() < 2]

In [None]:
# Every remaining case must carry at least 30 characters of text in total.
assert (cases.total_text_len >= 30).all()

In [None]:
# Diagnosis columns of every spreadsheet row.
# The rename applied when raw_data is loaded maps 'AIS  Code' (two spaces) to 'ais_code',
# whereas this selection used the un-renamed 'AIS Code' (one space); depending on the real
# header one of the two cannot match. Use whichever of the two columns is actually present.
ais_code_column = 'ais_code' if 'ais_code' in raw_data.columns else 'AIS Code'
case_icd_codes = raw_data[['patient_id', 'icd_code', 'icd_name', 'diagnoses_region', ais_code_column, 'ais_name']]

In [None]:
# Drop rows whose ICD code is the literal string 'NA' or empty.
case_icd_codes = case_icd_codes[~case_icd_codes.icd_code.isin(['NA', ''])]

In [None]:
# cases has not been modified since its shape was last displayed; shown again for reference.
cases.shape

In [None]:
# One row per patient that has ICD codes but is not among the retained cases.
case_icd_codes[~case_icd_codes.patient_id.isin(cases.patient_id)].drop_duplicates('patient_id')

In [None]:
# Keep only the code rows that belong to a retained case.
case_icd_codes = case_icd_codes[case_icd_codes.patient_id.isin(cases.patient_id)]

In [None]:
# Load the tab-separated ICD-10 concept table from the OHDSI vocabulary folder.
# The three listed columns are forced to str so that pandas does not infer another type for them.
icd_vocab = pd.read_csv(
    os.path.join(OHDSI_VOCAB_HOME, 'ICD10_CONCEPT.csv.gz'),
    sep='\t',
    dtype={'standard_concept': str, 'concept_code': str, 'invalid_reason': str},
)

In [None]:
# Every ICD code attached to a case must exist in the vocabulary.
assert case_icd_codes.icd_code.isin(icd_vocab.concept_code).all()

In [None]:
# Shape of the code rows whose ICD code does not start with 'S'.
case_icd_codes[~case_icd_codes.icd_code.str.startswith('S')].shape

In [None]:
# Number of cases that have no ICD code row at all.
cases[~cases.patient_id.isin(case_icd_codes.patient_id)].shape[0]

In [None]:
# Attach the vocabulary columns to each code row (left join, so every code row is kept).
# NOTE(review): no `on=` is given, so pandas joins on every column name the two frames share;
# this is presumably only 'icd_code' — confirm against the vocabulary file's columns.
case_icd_codes = case_icd_codes.merge(icd_vocab.rename(columns={'concept_code': 'icd_code'}), how='left')
# Keep only codes that start with 'S'.
case_icd_codes = case_icd_codes[case_icd_codes.icd_code.str.startswith('S')]

In [None]:
# Rows whose valid_end_date differs from 20991231 are displayed as deprecated codes.
print('Deprecated ICD10 codes')

case_icd_codes[case_icd_codes.valid_end_date != 20991231]

In [None]:
# Number of code rows per concept class.
case_icd_codes.groupby('concept_class_id').size()

In [None]:
# Keep only the cases that still have at least one 'S' code row.
cases = cases[cases.patient_id.isin(case_icd_codes.patient_id)]

In [None]:
# Shrink ICD codes to 5-character codes. Some 5-character codes do not exist in the
# vocabulary, so those are shrunk further to 4-character codes.
# NOTE(review): the slices keep 6 and 5 string characters respectively — presumably because
# the code strings contain a '.' separator; confirm against the vocabulary.
case_icd_codes_5char = case_icd_codes[['patient_id', 'icd_code']].copy()
case_icd_codes_5char['icd_code'] = case_icd_codes_5char['icd_code'].str[:6]

# Split into truncated codes known to the vocabulary and the ones that need the shorter fallback.
in_vocab = case_icd_codes_5char['icd_code'].isin(icd_vocab.concept_code)
invalid = case_icd_codes_5char[~in_vocab].copy()
invalid['icd_code'] = invalid['icd_code'].str[:5]
assert invalid['icd_code'].isin(icd_vocab.concept_code).all()

# Recombine, remove repeated patient/code pairs and attach the vocabulary columns.
case_icd_codes_5char = (
    pd.concat([case_icd_codes_5char[in_vocab], invalid])
    .sort_values('patient_id')
    .drop_duplicates()
    .merge(icd_vocab.rename(columns={'concept_code': 'icd_code'}), how='left')
)
assert case_icd_codes_5char['concept_id'].notnull().all()
assert (case_icd_codes_5char['valid_end_date'] == 20991231).all()

In [None]:
# Number of shrunk (5-character, or 4-character fallback) code rows per concept class.
case_icd_codes_5char.groupby('concept_class_id').size()

In [None]:
# Shrink ICD codes to 4-character codes (the first 5 string characters of each code).
# Unlike the 5-character variant there is no fallback: every shrunk code must be in the vocabulary.

case_icd_codes_4char = case_icd_codes[['patient_id', 'icd_code']].copy()
case_icd_codes_4char['icd_code'] = case_icd_codes_4char['icd_code'].str[:5]
assert case_icd_codes_4char['icd_code'].isin(icd_vocab.concept_code).all()

# Remove repeated patient/code pairs and attach the vocabulary columns.
case_icd_codes_4char = (
    case_icd_codes_4char
    .sort_values('patient_id')
    .drop_duplicates()
    .merge(icd_vocab.rename(columns={'concept_code': 'icd_code'}), how='left')
)
assert case_icd_codes_4char['concept_id'].notnull().all()
assert (case_icd_codes_4char['valid_end_date'] == 20991231).all()

In [None]:
# Row counts of the four tables that are about to be written to disk.
print(f'Number of cases {len(cases)}')
print(f'Number of icd codes {len(case_icd_codes)}')
print(f'Number of 5char icd codes {len(case_icd_codes_5char)}')
print(f'Number of 4char icd codes {len(case_icd_codes_4char)}')

In [None]:
# Write the cases and the three code tables to DS_HOME as CSV files without the index column.
cases.to_csv(os.path.join(DS_HOME, 'case.csv'), index=False)
case_icd_codes.to_csv(os.path.join(DS_HOME, 'case-icd-code.csv'), index=False)
case_icd_codes_5char.to_csv(os.path.join(DS_HOME, 'case-icd-code-5-char.csv'), index=False)
case_icd_codes_4char.to_csv(os.path.join(DS_HOME, 'case-icd-code-4-char.csv'), index=False)

## All 5-char icd codes

In [None]:
# Vocabulary entries whose valid_end_date is 20991231 and whose code starts with 'S'.
has_open_end_date = icd_vocab.valid_end_date == 20991231
is_s_code = icd_vocab.concept_code.str.startswith('S')
valid_icd_s = icd_vocab[has_open_end_date & is_s_code].copy()
assert valid_icd_s.concept_class_id.notnull().all()

In [None]:
# Number of valid 'S' vocabulary entries per concept class.
valid_icd_s.groupby('concept_class_id').size()

In [None]:
# Distinct concept codes in the full vocabulary (all codes, not only those starting with 'S').
icd_vocab.concept_code.unique()

In [None]:
# Start from the valid 'S' entries whose class is '5-char nonbill code'.
valid_icd_s_5_char = valid_icd_s[valid_icd_s.concept_class_id == '5-char nonbill code']
# Take the 6- and 7-character entries and cut their codes to the first 6 string characters.
tmp = valid_icd_s[valid_icd_s.concept_class_id.isin(['6-char nonbill code', '7-char billing code'])][['concept_code']].copy()
tmp.concept_code = tmp.concept_code.apply(lambda x: x[:6])
# Codes whose 6-character prefix is not a '5-char nonbill code' are cut once more, to 5 string
# characters; each of those shorter codes must be a valid 'S' entry.
tmp = tmp[~tmp.concept_code.isin(valid_icd_s_5_char.concept_code)]
tmp.concept_code = tmp.concept_code.apply(lambda x: x[:5])
assert all(tmp.concept_code.isin(valid_icd_s.concept_code))

# Add the vocabulary rows of those fallback codes; the result must have unique codes, all from ICD10CM.
valid_icd_s_5_char = pd.concat([valid_icd_s_5_char, valid_icd_s[valid_icd_s.concept_code.isin(tmp.concept_code)]]).copy()
assert not any(valid_icd_s_5_char.duplicated('concept_code'))
assert all(valid_icd_s_5_char.vocabulary_id == 'ICD10CM')
valid_icd_s_5_char = valid_icd_s_5_char[['concept_id', 'concept_name', 'concept_class_id', 'concept_code']].copy().sort_values('concept_code')

# Count the rows of case_icd_codes_5char per code and left-join the counts, so that codes
# without any case get n_cases = 0.
n_cases_by_5_char_code = case_icd_codes_5char.groupby('icd_code', as_index=False).agg({'patient_id': 'count'}).rename(columns={'patient_id': 'n_cases'})
valid_icd_s_5_char = valid_icd_s_5_char.rename(columns={'concept_code': 'icd_code'}).merge(n_cases_by_5_char_code, how='left')
valid_icd_s_5_char.n_cases = valid_icd_s_5_char.n_cases.fillna(0).astype(int)

In [None]:
# Write the 5-character label vocabulary with its per-code case counts to DS_HOME.
valid_icd_s_5_char[['concept_id', 'icd_code', 'concept_name', 'concept_class_id', 'n_cases']].to_csv(os.path.join(DS_HOME, 'icd10-5-char-vocab.csv'), index=False)

## All 4-char icd codes

In [None]:
# Valid 'S' entries whose class is '4-char nonbill code'.
valid_icd_s_4_char = valid_icd_s[valid_icd_s.concept_class_id == '4-char nonbill code']
# Sanity check: the first 5 string characters of every 6- and 7-character code must be one of them.
tmp = valid_icd_s[valid_icd_s.concept_class_id.isin(['6-char nonbill code', '7-char billing code'])][['concept_code']].copy()
tmp.concept_code = tmp.concept_code.apply(lambda x: x[:5])
assert all(tmp.concept_code.isin(valid_icd_s_4_char.concept_code))

# Codes must be unique and all come from the ICD10CM vocabulary.
assert not any(valid_icd_s_4_char.duplicated('concept_code'))
assert all(valid_icd_s_4_char.vocabulary_id == 'ICD10CM')
valid_icd_s_4_char = valid_icd_s_4_char[['concept_id', 'concept_name', 'concept_class_id', 'concept_code']].copy().sort_values('concept_code')

# Count the rows of case_icd_codes_4char per code and left-join the counts, so that codes
# without any case get n_cases = 0.
n_cases_by_4_char_code = case_icd_codes_4char.groupby('icd_code', as_index=False).agg({'patient_id': 'count'}).rename(columns={'patient_id': 'n_cases'})
valid_icd_s_4_char = valid_icd_s_4_char.rename(columns={'concept_code': 'icd_code'}).merge(n_cases_by_4_char_code, how='left')
valid_icd_s_4_char.n_cases = valid_icd_s_4_char.n_cases.fillna(0).astype(int)

In [None]:
# Write the 4-character label vocabulary with its per-code case counts to DS_HOME.
valid_icd_s_4_char[['concept_id', 'icd_code', 'concept_name', 'concept_class_id', 'n_cases']].to_csv(os.path.join(DS_HOME, 'icd10-4-char-vocab.csv'), index=False)

## Prepare set of labels and split dataset

In [None]:
# The label set is every 4-character code with more than 5 cases; it is written one code per line.
labels = valid_icd_s_4_char[valid_icd_s_4_char.n_cases > 5].icd_code.to_list()
with open(os.path.join(DS_HOME, 'label.txt'), 'w') as f:
  f.write('\n'.join(labels))

In [None]:
# Read the label set back from disk, one code per line.
with open(os.path.join(DS_HOME, 'label.txt')) as f:
  labels = f.read().split('\n')
          

In [None]:
# Number of labels read from label.txt.
len(labels)

In [None]:
# Keep the 4-character code rows whose code is one of the selected labels.
case_labels = case_icd_codes_4char[case_icd_codes_4char.icd_code.isin(labels)][['patient_id', 'icd_code', 'concept_name']]
# Collapse to one row per patient: the distinct codes, sorted and joined with ',' into one string.
case_labels = case_labels.groupby('patient_id', as_index=False).agg({'icd_code': lambda x: ','.join(sorted(list(set(x))))})
# NOTE(review): the aggregation above only returns patient_id and icd_code, so concept_name is
# already gone and the 'concept_name' -> 'label_name' entry below renames nothing. Later cells
# read case_labels.label_name and treat `label` as a single code — confirm which layout is intended.
case_labels.rename(columns={'icd_code': 'label', 'concept_name': 'label_name'}, inplace=True)

In [None]:
# One row per assigned ICD code (full-length codes), using the label / label_name column names.
# rename() returns a new frame here; renaming a column slice with inplace=True operates on a
# possible copy of case_icd_codes and triggers pandas' SettingWithCopyWarning.
case_labels_all = (
    case_icd_codes[['patient_id', 'icd_code', 'concept_name']]
    .rename(columns={'icd_code': 'label', 'concept_name': 'label_name'})
)

In [None]:
# Write the per-patient labels to DS_HOME without the index column.
case_labels.to_csv(os.path.join(DS_HOME, 'case-labels.csv'), index=False)

In [None]:
# Shuffle the patient ids and split them into 15% validation, 15% test and 70% train.
# The fixed random_state makes the shuffle, and therefore the files written below,
# reproducible; without it every run of this cell produced a different split.
patient_ids = cases.patient_id.sample(frac=1, random_state=42).to_list()
validation_end = int(len(patient_ids)*0.15)
test_end = int(len(patient_ids)*0.30)
validation = patient_ids[:validation_end]
test = patient_ids[validation_end:test_end]
train = patient_ids[test_end:]

# Each split is stored as one patient id per line.
with open(os.path.join(DS_HOME, 'validation.txt'), 'w') as f:
  f.write('\n'.join([str(x) for x in validation]))

with open(os.path.join(DS_HOME, 'test.txt'), 'w') as f:
  f.write('\n'.join([str(x) for x in test]))

with open(os.path.join(DS_HOME, 'train.txt'), 'w') as f:
  f.write('\n'.join([str(x) for x in train]))

# Train ids followed by validation ids in a single file.
with open(os.path.join(DS_HOME, 'train_and_validation.txt'), 'w') as f:
  f.write('\n'.join([str(x) for x in train]+[str(x) for x in validation]))

In [None]:
# Read the three splits back from disk; each becomes a list of id strings, one per line of its file.
with open(os.path.join(DS_HOME, 'validation.txt')) as f:
  validation = f.read().split('\n')

with open(os.path.join(DS_HOME, 'test.txt')) as f:
  test = f.read().split('\n')

with open(os.path.join(DS_HOME, 'train.txt')) as f:
  train = f.read().split('\n')

In [None]:
# The three splits must not share any patient id.
# Set disjointness checks each pair once in linear time; the nested list-membership loops
# tested every pair twice and took quadratic time in the number of patients.
validation_ids, test_ids, train_ids = set(validation), set(test), set(train)
assert validation_ids.isdisjoint(train_ids)
assert validation_ids.isdisjoint(test_ids)
assert train_ids.isdisjoint(test_ids)

In [None]:
# Number of ids in the validation split.
len(validation)

In [None]:
# Number of ids in the train split.
len(train)

# Exploratory Data Analysis

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
def save_plot(plot_name, dpi=1000):
  """Save the current matplotlib figure under PLOT_HOME and print where it was written.

  Args:
    plot_name: File name of the figure, joined onto PLOT_HOME.
    dpi: Resolution passed to plt.savefig; defaults to 1000.
  """
  plot_path = os.path.join(PLOT_HOME, plot_name)
  # savefig does not create missing folders, so make sure the target folder exists first.
  os.makedirs(os.path.dirname(plot_path), exist_ok=True)
  plt.savefig(plot_path, dpi=dpi)
  print("Saved plot to", plot_path)

In [None]:
# Collect the set of label codes of every patient.
# case_labels stores a patient's codes as one comma-joined string, so each value is split
# before it is added; adding the joined string itself counted every patient as one code.
# A value holding a single code is handled the same way.
patient_to_codes = defaultdict(set)
for pid, label in zip(case_labels["patient_id"], case_labels["label"]):
  patient_to_codes[pid].update(code for code in str(label).split(',') if code)

# Number of distinct codes per patient, plus the derived arrays used for the histogram.
patient_to_count = {pid: len(codes) for pid, codes in patient_to_codes.items()}
counts = list(patient_to_count.values())
counts_cumu = np.array(counts).cumsum()
# NOTE(review): with these edges the last histogram bin is [max-1, max], so patients with
# max-1 and max codes share a bar — confirm this is intended.
counts_bins = np.arange(0, max(counts) + 1, 1)

In [None]:
# Histogram of the number of label codes per patient, saved through save_plot().
fig, ax = plt.subplots(figsize=(30, 15))
# ax2 = ax.twinx()
n, bins, patches = ax.hist(counts, bins=counts_bins, density=False, alpha=0.5)
# n, bins, patches = ax2.hist(counts, cumulative=1, histtype='step', bins=counts_bins, color='tab:orange', alpha=0.5)
# Keep the automatic left limit and end the x-axis at the largest observed count.
ax.set_xlim((ax.get_xlim()[0], max(counts)))
ax.grid(True, ls="--")
ax.set_ylabel("Number of trauma tertiary surveys", fontsize=40)
ax.set_xlabel("Number of injury ICD-10 diagnosis codes extracted by trauma registrars per admission encounter", fontsize=40)
ax.tick_params(axis='both', which='major', labelsize=30)
ax.tick_params(axis='both', which='minor', labelsize=30)
save_plot("code_counts_high_res.png")

In [None]:
# Tot number of annotations, Tot annotations used, Percentage
len(case_icd_codes), sum(counts), sum(counts)/len(case_icd_codes)

In [None]:
# Reload the cases from disk. patient_id is read back as str, matching the in-memory frame;
# without the dtype pandas would infer a numeric type for ids that consist of digits only.
cases = pd.read_csv(os.path.join(DS_HOME, 'case.csv'), dtype={'patient_id': str})

In [None]:
# Rows of case_labels whose label name does not contain 'uperficial'.
# NOTE(review): case_labels as built earlier in this notebook only has patient_id and label,
# so label_name may be missing here — confirm.
case_labels[~case_labels.label_name.str.contains("uperficial")]

In [None]:
print("Mean and SD of 4-character non-superficial ICD-10 codes per patient")

# For each label table: drop rows whose label name contains 'uperficial', count the remaining
# rows per patient, and print the mean and standard deviation of those counts (3 decimals).
# NOTE(review): this reads label_name from case_labels, which the cell building case_labels
# does not appear to keep — confirm.
for name, temp in zip(["4-char non_sup", "all"], [case_labels, case_labels_all]):
  temp = temp[~temp.label_name.str.contains('uperficial')]
  print(name, end="  ")
  codes_per_patient = temp.groupby("patient_id")["label"].count()
  print(round(codes_per_patient.mean(), 3), round(codes_per_patient.std(), 3))

## Split labels

In [None]:
# Read label.txt (one code per line) straight into a single-column DataFrame named 'label'.
with open(os.path.join(DS_HOME, 'label.txt')) as f:
  labels = pd.DataFrame({'label': f.read().split('\n')})

In [None]:
# Reload the label tables from disk, keeping patient_id as str as in the rest of the notebook
# (pandas would otherwise infer a numeric type for ids that consist of digits only).
# NOTE(review): 'case-labels-5-char.csv' is not written by any cell visible in this notebook —
# confirm where it is produced.
case_labels = pd.read_csv(os.path.join(DS_HOME, 'case-labels.csv'), dtype={'patient_id': str})
case_labels_5char = pd.read_csv(os.path.join(DS_HOME, 'case-labels-5-char.csv'), dtype={'patient_id': str})

In [None]:
# Per label: the number of case rows (n_cases) and the first label name seen, most frequent first.
label_counts = (
    case_labels
    .groupby('label', as_index=False)
    .agg(n_cases=('patient_id', 'count'), label_name=('label_name', 'first'))
    .sort_values('n_cases', ascending=False)
)


In [None]:
# For the `top` most frequent non-superficial labels, print the number of case_labels rows
# carrying one of them divided by the number of distinct patients.
# The slice is [:top]; the former [:top+1] selected one label more than the printed "Top N"
# and disagreed with the [:10] / [:50] slices used when the top-N label files are written.
non_superficial = label_counts[~label_counts.label_name.str.contains('uperficial')]
n_patients = case_labels.patient_id.unique().shape[0]
for top in [10, 50, 170]:
  print("Top", top, end=" ")
  codes = non_superficial[:top]
  temp = case_labels[case_labels.label.isin(codes.label)]
  print(round(len(temp) / n_patients, 3))

In [None]:
# Per 5-character label: number of case rows and first label name, most frequent first,
# restricted to labels with more than 5 cases.
label_counts_5char = (
    case_labels_5char
    .groupby('label', as_index=False)
    .agg(n_cases=('patient_id', 'count'), label_name=('label_name', 'first'))
    .sort_values('n_cases', ascending=False)
    .loc[lambda counted: counted.n_cases > 5]
)
len(label_counts_5char)

In [None]:
# Labels whose name does not contain 'uperficial', in label_counts order (most frequent first).
non_superficial_labels = label_counts[~label_counts.label_name.str.contains('uperficial')].label

# Full list, then the first 10 and the first 50, each written one label per line.
with open(os.path.join(DS_HOME, 'label-non-superficial.txt'), 'w') as f:
  f.write('\n'.join(non_superficial_labels))

with open(os.path.join(DS_HOME, 'label-non-superficial-top10.txt'), 'w') as f:
  f.write('\n'.join(non_superficial_labels[:10]))

with open(os.path.join(DS_HOME, 'label-non-superficial-top50.txt'), 'w') as f:
  f.write('\n'.join(non_superficial_labels[:50]))

In [None]:
# Keep only 5-character labels with more than 5 cases, then select those whose name does not
# contain 'uperficial', in label_counts_5char order.
label_counts_5char = label_counts_5char[label_counts_5char.n_cases > 5]
non_superficial_labels_5char = label_counts_5char[~label_counts_5char.label_name.str.contains('uperficial')].label

# Full list, then the first 10 and the first 50, written to the pretraining dataset folder.
with open(os.path.join(DS_HOME_PRETRAIN, 'label-non-superficial-5-char.txt'), 'w') as f:
  f.write('\n'.join(non_superficial_labels_5char))

with open(os.path.join(DS_HOME_PRETRAIN, 'label-non-superficial-5-char-top10.txt'), 'w') as f:
  f.write('\n'.join(non_superficial_labels_5char[:10]))

with open(os.path.join(DS_HOME_PRETRAIN, 'label-non-superficial-5-char-top50.txt'), 'w') as f:
  f.write('\n'.join(non_superficial_labels_5char[:50]))

In [None]:
# Write both label-count tables to DS_HOME without the index column.
label_counts.to_csv(os.path.join(DS_HOME, 'label-case-count.csv'), index=False)
label_counts_5char.to_csv(os.path.join(DS_HOME, 'label-case-count-5char.csv'), index=False)