In [14]:
import pandas as pd
import numpy as np
import re
import tqdm
import pickle

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import GPT2Model
from utils import *

In [15]:
from icd9cms.icd9 import search

# Procedure-code lookup table used by print_seq_dsc (which strips dots from a code before the lookup).
# SECURITY: pickle.load can execute arbitrary code - only load this file from a trusted source.
with open(file="./data/pcode_dict.txt", mode="rb") as fp:
    icd9_pcode_dict = pickle.load(fp)

In [2]:
# All three objects come from the same Hub checkpoint. GPT2Model loads the transformer
# without the LM head, hence the unused `lm_head.weight` warning in the output below.
NTDB_CHECKPOINT = 'dracoglacius/NTDB-GPT2'
NTDBGPT2_tokenizer = AutoTokenizer.from_pretrained(NTDB_CHECKPOINT)
NTDBGPT2_lm = AutoModelForCausalLM.from_pretrained(NTDB_CHECKPOINT)
NTDBGPT2_embed = GPT2Model.from_pretrained(NTDB_CHECKPOINT)

Some weights of the model checkpoint at dracoglacius/NTDB-GPT2 were not used when initializing GPT2Model: ['lm_head.weight']
- This IS expected if you are initializing GPT2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## ECodes

* E812.0 = Other motor vehicle traffic accident involving collision with motor vehicle, injuring driver of motor vehicle other than motorcycle
* E885.9 = Accidental fall from other slipping, tripping, or stumbling
* E966.0 = Assault by cutting and piercing instrument
* E965.4 = Assault by other and unspecified firearm
* E924.0 = Accident caused by hot liquids and vapors, including steam

In [46]:
# Training sequences for the five E-codes named in the file name (E885.9, E812.0, E966.0, E965.4, E924.0).
TRAIN_SEQ_PATH = "./data/25k_train_seqs_3_22_E8859_E8120_E9660_E9654_E9240.npy"
trn_seq = np.load(TRAIN_SEQ_PATH)

In [47]:
# Peek at the loaded sequences (numpy abbreviates the repr of large arrays with `...`).
trn_seq

array(['<START> E885.9 <DSTART> 853.00 873.0 <PSTART> 01.39 <END>',
       '<START> E885.9 <DSTART> 820.21 812.01 <PSTART> 79.01 00.33 79.35 <END>',
       '<START> E885.9 <DSTART> 831.01 812.09 <PSTART> <UNK> <END>', ...,
       '<START> E924.0 <DSTART> 948.00 <PSTART> 86.28 <END>',
       '<START> E924.0 <DSTART> 948.00 945.26 842.13 <PSTART>  <END>',
       '<START> E924.0 <DSTART> 948 944.37 944.36 944.32 944.31 943.31 <PSTART>  <END>'],
      dtype='<U622')

## Separate Data

#### Training Data is In Domain Data

In [48]:
# Split the training set by E-code (plain substring match on each sequence string).
# Each resulting list was noted to hold 5000 items.
(
    e8120_trn_seq,
    e8859_trn_seq,
    e9660_trn_seq,
    e9654_trn_seq,
    e9240_trn_seq,
) = (
    [seq for seq in trn_seq if ecode in seq]
    for ecode in ('E812.0', 'E885.9', 'E966.0', 'E965.4', 'E924.0')
)

#### From the Training Data Create Adversarial Examples

In [68]:
def print_seq_dsc(seq):
    """Print a human-readable description of every code in one token sequence.

    The sequence is split on whitespace:

    * Marker tokens (``<START>``, ``<DSTART>``, ``<PSTART>``, ``<END>``) print a
      banner; ``<DSTART>``/``<PSTART>`` also switch to the diagnosis/procedure
      section.
    * ``<UNK>`` prints an "Unknown Code" line.
    * Diagnosis codes print the result of ``search(code)``.
    * Procedure codes have their dots removed and are looked up in
      ``icd9_pcode_dict``.
    * Codes seen before ``<DSTART>`` (e.g. the E-code after ``<START>``) print
      nothing.
    """
    # marker token -> (section to switch to, or None to keep the current one; banner text)
    markers = {
        '<START>': (None, '=' * 9 + ' START ' + '=' * 9),
        '<DSTART>': ('DX', '=' * 10 + ' DXS ' + '=' * 10),
        '<PSTART>': ('PR', '=' * 10 + ' PRS ' + '=' * 10),
        '<END>': (None, '=' * 10 + ' END ' + '=' * 10),
    }
    section = 'START'
    for token in seq.split():
        if token in markers:
            next_section, banner = markers[token]
            if next_section is not None:
                section = next_section
            print(banner)
        elif token == '<UNK>':
            print(f'{token}:Unknown Code')
        elif section == 'DX':
            print(search(token))
        elif section == 'PR':
            # The procedure lookup table is keyed by codes without the dot.
            proc_code = token.replace('.', '')
            if proc_code in icd9_pcode_dict:
                print(f"{proc_code}:{icd9_pcode_dict[proc_code]}")
            else:
                print(f'{proc_code}:Unknown Code')

In [69]:
def print_seqs(seqs):
    """Describe each sequence in ``seqs`` with ``print_seq_dsc``, adding a blank line after each one."""
    for one_seq in iter(seqs):
        print_seq_dsc(one_seq)
        print()

In [70]:
# Describe the first 100 E812.0 training sequences.
print_seqs(seqs=e8120_trn_seq[:100])

8600:Traum pneumothorax-close:Traumatic pneumothorax without mention of open wound into thorax
80702:Fracture two ribs-closed:Closed fracture of two ribs
9352:Application of neck support
8801:Computerized axial tomography of abdomen
8741:Computerized axial tomography of thorax

9190:Abrasion NEC:Abrasion or friction burn of other, multiple, and unspecified sites, without mention of infection
8798:Open wound site NOS:Open wound(s) (multiple) of unspecified site(s), without mention of complication
8509:Concussion NOS:Concussion, unspecified
82525:Fx metatarsal-closed:Closed fracture of metatarsal bone(s)

79311:Solitary pulmonry nodule:Solitary pulmonary nodule
83304:Disloc carpometacarp-cl:Closed dislocation of carpometacarpal (joint)
7856:Enlargement lymph nodes:Enlargement of lymph nodes
784:Symptoms involving head and neck:None
7339:Other and unspecified disorders of bone and cartilage:None
733:Other disorders of bone and cartilage:None
721:Spondylosis and allied disorders:None
428:H

In [24]:
# `get_seq_dsc` is not defined anywhere in this notebook. It may only exist in `utils` or
# in a since-deleted cell, because this cell ran before `print_seq_dsc` was defined.
# Call `print_seq_dsc`, defined earlier in the notebook, so the cell survives Restart & Run All.
print_seq_dsc(e8120_trn_seq[0])

8600:Traum pneumothorax-close:Traumatic pneumothorax without mention of open wound into thorax
80702:Fracture two ribs-closed:Closed fracture of two ribs
9352:Application of neck support
8801:Computerized axial tomography of abdomen
8741:Computerized axial tomography of thorax


In [10]:
# Look up the ICD-9 entry for E-code E812.0 (output shows code, short and long description).
search('E812.0')

E8120:Mv collision NOS-driver:Other motor vehicle traffic accident involving collision with motor vehicle injuring driver of motor vehicle other than motorcycle

Our intuition is as follows. Model the in-domain distribution as a hyper-ellipsoid with center $c$ and shape $\Sigma$. Then $c$ and $\Sigma$ should not deviate from the empirical mean ($\hat{c}$) and the covariance estimate ($\hat{\Sigma}$) computed from the training data.

To obtain these estimates we need to:

1. Feed the training data ($n$ sequences) to the NTDB model and, from the last token, take the features of every layer ($n \times 13 \times 768$)
1. Calculate the sample mean ($\hat{c}$) and the covariance estimate ($\hat{\Sigma}$), together with the estimated precision matrix, i.e. the pseudo-inverse of $\hat{\Sigma}$ (exposed as `.precision_` in sklearn)

To obtain the OOD estimation we need to:

1. Calculate the Mahalanobis Distance Feature (MDF) for the generated data. This should be a vector whose length equals the number of layers.
1. Calculate the Anomaly Score
   1. This is a one-class SVM with a linear kernel that uses the MDF as its features

In [7]:
def get_hidden_embeddings(hidden_states, use_last=True):
    """Arrange the hidden states returned by ``generate`` as (token, layer, hidden).

    Parameters
    ----------
    hidden_states : tuple of tuple of torch.Tensor
        ``out.hidden_states`` from a Hugging Face ``generate`` call with
        ``return_dict_in_generate=True`` and ``output_hidden_states=True``.
        There is one entry per generation step, and each entry has one tensor
        per layer (13 for this GPT-2 checkpoint, i.e. embeddings + 12 blocks).

        * Step 0 covers the whole prompt, shaped ``(1, prompt_len, hidden)``.
        * Later steps cover a single token, shaped ``(1, 1, hidden)``.

        Only batch size 1 is supported.
    use_last : bool, default True
        If True, return only the last row (the final position's layer stack).

    Returns
    -------
    torch.Tensor
        ``(n_layers, hidden)`` if ``use_last``. Otherwise
        ``(prompt_len + len(hidden_states) - 1, n_layers, hidden)``.

    NOTE(review): in HF ``generate`` the last sampled token is never fed back
    to the model. So the last row presumably belongs to the token before it;
    confirm this is the intended "last token".
    """
    # Prompt step: stacking the layers gives (n_layers, 1, prompt_len, hidden).
    # Moving tokens to the front and dropping the batch axis gives
    # (prompt_len, n_layers, hidden).
    prompt_part = torch.stack(hidden_states[0]).transpose(0, 2).squeeze(1)

    # Each later step holds exactly one token, giving (n_layers, hidden).
    # The sizes are read from the tensors instead of being hard-coded as 13 x 768.
    step_rows = [
        torch.stack(step).reshape(len(step), step[0].shape[-1])
        for step in hidden_states[1:]
    ]
    if step_rows:
        embeddings = torch.cat([prompt_part, torch.stack(step_rows)])
    else:
        # Generation stopped after one step, so only the prompt states exist.
        # torch.stack([]) would raise here.
        embeddings = prompt_part

    return embeddings[-1] if use_last else embeddings

In [8]:
def get_embeddings(sequences, use_last=True):
    """Run the NTDB GPT-2 LM on each sequence and collect its per-layer hidden states.

    For every input string, ``NTDBGPT2_lm.generate`` samples a continuation.
    ``do_sample=True`` with ``top_k=0`` means no top-k filtering, so results
    vary between runs unless the torch RNG is seeded. EOS is forced at the
    length limit. ``get_hidden_embeddings`` then reshapes the returned hidden
    states.

    Parameters
    ----------
    sequences : iterable of str
        Token-sequence strings to encode with ``NTDBGPT2_tokenizer``.
    use_last : bool, default True
        Forwarded to ``get_hidden_embeddings``.

    Returns
    -------
    torch.Tensor or list of torch.Tensor
        * ``use_last=True``: one stacked tensor. Stacking raises if
          ``sequences`` is empty.
        * Otherwise: a list with one tensor per sequence. Their first
          dimension may differ between sequences.
    """
    collected = []
    for text in tqdm.tqdm(sequences):
        input_ids = NTDBGPT2_tokenizer.encode(text, return_tensors='pt')
        generation = NTDBGPT2_lm.generate(
            input_ids,
            do_sample=True,
            top_k=0,
            return_dict_in_generate=True,
            forced_eos_token_id=NTDBGPT2_tokenizer.eos_token_id,
            output_hidden_states=True,
        )
        collected.append(get_hidden_embeddings(generation.hidden_states, use_last))
    return torch.stack(collected) if use_last else collected

#### Get Sequence Embeddings of All Layers

In [9]:
# Disabled cache-building cell for E812.0.
# It computes all-token and last-token layer embeddings for the training and generated
# sequences and saves them under ./outputs/.
# NOTE(review): `e8120_gen_seq` is not defined anywhere in this notebook. Define it before
# re-enabling, or this cell fails on a fresh kernel.
# NOTE(review): with use_last=False, get_embeddings returns a list of tensors whose first
# dimension may differ per sequence. np.save would then store a ragged object array that
# needs allow_pickle=True to load; TODO confirm.
# The `*_end_*` stacks take x[-1,:,:], which is what get_embeddings(..., use_last=True) returns.
#e8120_trn_all_token_layer_embeddings = get_embeddings(e8120_trn_seq, use_last=False)
#e8120_gen_all_token_layer_embeddings = get_embeddings(e8120_gen_seq, use_last=False)
#np.save("./outputs/e8120_trn_all_em.npy", e8120_trn_all_token_layer_embeddings)
#np.save("./outputs/e8120_gen_all_em.npy", e8120_gen_all_token_layer_embeddings)

#e8120_trn_end_token_layer_embeddings = torch.stack([x[-1,:,:] for x in e8120_trn_all_token_layer_embeddings])
#np.save("./outputs/e8120_trn_end_em.npy", e8120_trn_end_token_layer_embeddings)
#e8120_gen_end_token_layer_embeddings = torch.stack([x[-1,:,:] for x in e8120_gen_all_token_layer_embeddings])
#np.save("./outputs/e8120_gen_end_em.npy", e8120_gen_end_token_layer_embeddings)

In [10]:
# Disabled cache-building cell for E885.9.
# It computes all-token and last-token layer embeddings for the training and generated
# sequences and saves them under ./outputs/.
# NOTE(review): `e8859_gen_seq` is not defined anywhere in this notebook. Define it before
# re-enabling, or this cell fails on a fresh kernel.
# NOTE(review): with use_last=False, get_embeddings returns a list of tensors whose first
# dimension may differ per sequence. np.save would then store a ragged object array that
# needs allow_pickle=True to load; TODO confirm.
# The `*_end_*` stacks take x[-1,:,:], which is what get_embeddings(..., use_last=True) returns.
#e8859_trn_all_token_layer_embeddings = get_embeddings(e8859_trn_seq, use_last=False)
#e8859_gen_all_token_layer_embeddings = get_embeddings(e8859_gen_seq, use_last=False)
#np.save("./outputs/e8859_trn_all_em.npy", e8859_trn_all_token_layer_embeddings)
#np.save("./outputs/e8859_gen_all_em.npy", e8859_gen_all_token_layer_embeddings)

#e8859_trn_end_token_layer_embeddings = torch.stack([x[-1,:,:] for x in e8859_trn_all_token_layer_embeddings])
#np.save("./outputs/e8859_trn_end_em.npy", e8859_trn_end_token_layer_embeddings)

#e8859_gen_end_token_layer_embeddings = torch.stack([x[-1,:,:] for x in e8859_gen_all_token_layer_embeddings])
#np.save("./outputs/e8859_gen_end_em.npy", e8859_gen_end_token_layer_embeddings)

In [11]:
# Disabled cache-building cell for E966.0.
# It computes all-token and last-token layer embeddings for the training and generated
# sequences and saves them under ./outputs/.
# The output below (two 5000-item progress bars and two `np.asanyarray` warnings) was
# presumably produced while this cell was still active.
# NOTE(review): `e9660_gen_seq` is not defined anywhere in this notebook. Define it before
# re-enabling, or this cell fails on a fresh kernel.
# NOTE(review): with use_last=False, get_embeddings returns a list of tensors whose first
# dimension may differ per sequence. np.save then likely builds a ragged object array (the
# probable source of the warnings) that needs allow_pickle=True to load; TODO confirm.
# The `*_end_*` stacks take x[-1,:,:], which is what get_embeddings(..., use_last=True) returns.
#e9660_trn_all_token_layer_embeddings = get_embeddings(e9660_trn_seq, use_last=False)
#e9660_gen_all_token_layer_embeddings = get_embeddings(e9660_gen_seq, use_last=False)
#np.save("./outputs/e9660_trn_all_em.npy", e9660_trn_all_token_layer_embeddings)
#np.save("./outputs/e9660_gen_all_em.npy", e9660_gen_all_token_layer_embeddings)

#e9660_trn_end_token_layer_embeddings = torch.stack([x[-1,:,:] for x in e9660_trn_all_token_layer_embeddings])
#np.save("./outputs/e9660_trn_end_em.npy", e9660_trn_end_token_layer_embeddings)

#e9660_gen_end_token_layer_embeddings = torch.stack([x[-1,:,:] for x in e9660_gen_all_token_layer_embeddings])
#np.save("./outputs/e9660_gen_end_em.npy", e9660_gen_end_token_layer_embeddings)

100%|███████████████████████████████████████████████████████████| 5000/5000 [13:52<00:00,  6.01it/s]
100%|███████████████████████████████████████████████████████████| 5000/5000 [12:36<00:00,  6.61it/s]
  arr = np.asanyarray(arr)
  arr = np.asanyarray(arr)
