In [1]:
import pandas as pd
import numpy as np
import gc
import re
import tqdm
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import GPT2Model
from utils import *

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from icd9cms.icd9 import search
import pickle
with open("./data/pcode_dict.txt", mode="rb") as pcode_file:
    # pickle.load can execute arbitrary code: only load a trusted copy of this file.
    icd9_pcode_dict = pickle.load(pcode_file)

In [27]:
def print_seq_dsc(seq):
    """Print a human-readable description of a space-separated code sequence.

    Section markers print banners: <START>, <DSTART> (diagnoses), <PSTART>
    (procedures) and <END>. After <DSTART>, each code is described with
    ``search`` from icd9cms. After <PSTART>, each code has its dots removed and
    is looked up in ``icd9_pcode_dict``. Codes that cannot be resolved in either
    section, and <UNK> tokens, are reported as ``<code>:Unknown Code``.
    Tokens that appear before <DSTART> produce no output.

    Args:
        seq: whitespace-separated sequence of codes and section markers.
    """
    section = 'START'
    for code in seq.split():
        if code == '<START>':
            print('=' * 9 + ' START ' + '=' * 9)
        elif code == '<DSTART>':
            section = 'DX'
            print('=' * 10 + ' DXS ' + '=' * 10)
        elif code == '<PSTART>':
            section = 'PR'
            print('=' * 10 + ' PRS ' + '=' * 10)
        elif code == '<END>':
            print('=' * 10 + ' END ' + '=' * 10)
        elif code == '<UNK>':
            print(f'{code}:Unknown Code')
        elif section == 'DX':
            desc = search(code)
            # Report unresolved diagnosis codes the same way as procedures
            # instead of silently dropping them.
            print(desc if desc else f'{code}:Unknown Code')
        elif section == 'PR':
            # The procedure dictionary is keyed by codes without the decimal point.
            pr_cd = code.replace('.', '')
            if pr_cd in icd9_pcode_dict:
                print(f"{pr_cd}:{icd9_pcode_dict[pr_cd]}")
            else:
                print(f'{pr_cd}:Unknown Code')

In [4]:
# Hugging Face Hub id shared by the LM, its tokenizer and the bare transformer.
_NTDB_GPT2_CKPT = 'dracoglacius/NTDB-GPT2'
NTDBGPT2_lm = AutoModelForCausalLM.from_pretrained(_NTDB_GPT2_CKPT)
NTDBGPT2_tokenizer = AutoTokenizer.from_pretrained(_NTDB_GPT2_CKPT)
# Loading this checkpoint into a bare GPT2Model logs that 'lm_head.weight' was not used;
# that is expected because GPT2Model has no LM head.
NTDBGPT2_embed = GPT2Model.from_pretrained(_NTDB_GPT2_CKPT)

Some weights of the model checkpoint at dracoglacius/NTDB-GPT2 were not used when initializing GPT2Model: ['lm_head.weight']
- This IS expected if you are initializing GPT2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## ECodes

* E812.0 = Other motor vehicle traffic accident involving collision with motor vehicle injuring driver of motor vehicle other than motorcycle.
* E885.9 = Accidental fall from other slipping tripping or stumbling
* E966.0 = Assault by cutting and piercing instrument
* E965.4 = Assault-firearm NEC:Assault by other and unspecified firearm
* E924.0 = Acc-hot liquid & steam - Accident caused by hot liquids and vapors, including steam

# Adversarial Examples

* E812.0 = Other motor vehicle traffic accident involving collision with motor vehicle injuring driver of motor vehicle other than motorcycle.
* E965.4 = Assault-firearm NEC:Assault by other and unspecified firearm
* E924.0 = Acc-hot liquid & steam - Accident caused by hot liquids and vapors, including steam

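1. From the training set, obtain the stems (E-code and diagnosis codes) and the procedure codes for E812.0, E965.4, and E924.0
1. Create 6 sets by pairing the stems of one E-code with the procedures of another
1. Exclude sequences with total token length > 20 (`get_embeddings` skips prompts longer than 19 tokens)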
1. Create embeddings

* We treat these as adversarial examples because the stems and procedures themselves come from the training data
* The question is whether the OOD classifier can identify them as OOD based on the learned sequential information

In [8]:
# Training sequences; the file name appears to encode the five E-codes in this dataset.
trn_seq = np.load(file="./data/25k_train_seqs_3_22_E8859_E8120_E9660_E9654_E9240.npy")

## Separate Data

#### Training Data Is In-Domain Data

In [10]:
def _seqs_containing(ecode):
    """Return the training sequences whose text contains ``ecode`` (substring match)."""
    return [seq for seq in trn_seq if ecode in seq]

e8120_trn_seq = _seqs_containing('E812.0')  # 5000 items
e9654_trn_seq = _seqs_containing('E965.4')  # 5000 items
e9240_trn_seq = _seqs_containing('E924.0')  # 5000 items

#### Adversarial Data

* Stem: E812.0 + Procedures: E965.4
* Stem: E812.0 + Procedures: E924.0
* Stem: E965.4 + Procedures: E812.0
* Stem: E965.4 + Procedures: E924.0
* Stem: E924.0 + Procedures: E812.0
* Stem: E924.0 + Procedures: E965.4

In [13]:
def _split_stem_prcs(seqs):
    """Split each sequence at '<PSTART>' into (stems, procedures).

    The stem is the text before the first '<PSTART>'; the procedures are the
    text between the first and second '<PSTART>'. A sequence with no
    '<PSTART>' raises IndexError.
    """
    halves = [seq.split('<PSTART>') for seq in seqs]
    return [h[0] for h in halves], [h[1] for h in halves]

e8120_trn_stem, e8120_trn_prcs = _split_stem_prcs(e8120_trn_seq)

e9654_trn_stem, e9654_trn_prcs = _split_stem_prcs(e9654_trn_seq)

e9240_trn_stem, e9240_trn_prcs = _split_stem_prcs(e9240_trn_seq)

In [24]:
def _join_stem_prcs(stems, prcs):
    """Pair the i-th stem with the i-th procedure list, re-inserting '<PSTART>'.

    Like zip, stops at the shorter of the two lists.
    """
    return [stem + '<PSTART>' + prc for stem, prc in zip(stems, prcs)]

e8120_e9654_adv_seq = _join_stem_prcs(e8120_trn_stem, e9654_trn_prcs)
e8120_e9240_adv_seq = _join_stem_prcs(e8120_trn_stem, e9240_trn_prcs)
e9654_e8120_adv_seq = _join_stem_prcs(e9654_trn_stem, e8120_trn_prcs)
e9654_e9240_adv_seq = _join_stem_prcs(e9654_trn_stem, e9240_trn_prcs)
e9240_e8120_adv_seq = _join_stem_prcs(e9240_trn_stem, e8120_trn_prcs)
e9240_e9654_adv_seq = _join_stem_prcs(e9240_trn_stem, e9654_trn_prcs)

In [37]:
def get_hidden_embeddings(hidden_states, is_train=True, use_last=True):
    """Turn ``generate(..., output_hidden_states=True)`` hidden states into per-token embeddings.

    ``hidden_states[0]`` holds one tensor per layer for the whole prompt. Each is
    assumed to be indexed (batch=1, token, hidden). Every later entry holds one
    tensor per layer for a generated step, and the last token position of that
    tensor is taken.

    Args:
        hidden_states: sequence of per-step tuples of per-layer tensors (batch size 1).
        is_train: if True, only the prompt embeddings are used; otherwise the
            generated steps are appended after the prompt tokens.
        use_last: if True, return only the last token's embedding.

    Returns:
        A (tokens, layers, hidden) tensor, or (layers, hidden) when ``use_last``.
    """
    # Stack the layers, then move the token axis to the front:
    # (layers, 1, T, H) -> (T, 1, layers, H) -> (T, layers, H).
    prompt_em = torch.stack(hidden_states[0]).transpose(0, 2).squeeze(1)

    generated = hidden_states[1:]
    if is_train or len(generated) == 0:
        em = prompt_em
    else:
        # Per step: (layers, 1, steps_T, H) -> keep batch 0 and the newest token -> (layers, H).
        # This replaces a hard-coded reshape to [13, 768], which only fit GPT-2 small.
        step_em = torch.stack([torch.stack(step)[:, 0, -1, :] for step in generated])
        em = torch.cat([prompt_em, step_em])

    return em[-1] if use_last else em

In [38]:
def get_embeddings(sequences, is_train=True, use_last=True, max_prompt_len=19, return_kept=False):
    """Run each sequence through NTDB-GPT2 and collect hidden-state embeddings.

    Each sequence is tokenized with ``NTDBGPT2_tokenizer`` and passed to
    ``NTDBGPT2_lm.generate``, which samples with ``top_k=0`` and returns hidden
    states. Those states are reduced with
    ``get_hidden_embeddings(..., is_train, use_last)``.

    Args:
        sequences: iterable of space-separated code strings.
        is_train: forwarded to ``get_hidden_embeddings``.
        use_last: forwarded to ``get_hidden_embeddings``; also selects the return type.
        max_prompt_len: sequences whose tokenized length exceeds this are skipped.
            NOTE(review): 19 is presumably chosen so generate (default
            max_length 20) can add at least one token — confirm.
        return_kept: if True, also return the indices (into ``sequences``) of the
            sequences that were embedded, so outputs can be aligned with inputs.

    Returns:
        If ``use_last``, a tensor stacking one embedding per kept sequence (empty
        along dim 0 if none were kept). Otherwise, a list with one tensor per kept
        sequence. With ``return_kept=True``, a tuple ``(embeddings, kept_indices)``.
    """
    token_layer_embeddings = []
    kept_indices = []
    for idx, seq in enumerate(tqdm.tqdm(sequences)):
        seq_ids = NTDBGPT2_tokenizer.encode(seq, return_tensors='pt')
        if seq_ids.shape[-1] > max_prompt_len:
            continue  # too long to generate from; recorded via kept_indices
        out = NTDBGPT2_lm.generate(
            seq_ids,
            do_sample=True,
            top_k=0,
            return_dict_in_generate=True,
            forced_eos_token_id=NTDBGPT2_tokenizer.eos_token_id,
            output_hidden_states=True,
        )
        token_layer_embeddings.append(get_hidden_embeddings(out.hidden_states, is_train, use_last))
        kept_indices.append(idx)

    if not use_last:
        embeddings = token_layer_embeddings
    elif token_layer_embeddings:
        embeddings = torch.stack(token_layer_embeddings)
    else:
        # torch.stack([]) raises; return an empty (0, layers + 1, hidden) tensor instead.
        cfg = NTDBGPT2_lm.config
        embeddings = torch.empty((0, cfg.num_hidden_layers + 1, cfg.hidden_size))
    return (embeddings, kept_indices) if return_kept else embeddings

#### Get Sequence Embeddings of All Layers

In [40]:
def clean_seq(seq):
    """Collapse every run of whitespace in ``seq`` to a single space and trim the ends."""
    return ' '.join(seq.split())

In [41]:
def create_adversarial_embedding_data(ecode1, ecode2, seqs):
    """Embed one adversarial set and save its embeddings together with the embedded sequences.

    Files written to ./outputs/, using the prefix ``{ecode1}_{ecode2}``:
      * ``_adv_all_em.npy``: object array (load with ``allow_pickle=True``). Element i
        is the float array from ``get_embeddings(use_last=False)`` for the i-th kept
        sequence, with one row per token.
      * ``_adv_end_em.npy``: the last-token embedding of every kept sequence, stacked.
      * ``_adv_kept_seq.npy``: the kept sequences, row-aligned with both embedding files.

    Raises:
        ValueError: if no sequence fits within the 19-token prompt limit.
        RuntimeError: if get_embeddings returned a different number of embeddings
            than sequences passed to it.
    """
    import os

    prefix = f"./outputs/{ecode1}_{ecode2}"

    # get_embeddings silently skips prompts longer than 19 tokens. Apply the same
    # rule here so each saved embedding row can be matched to its sequence.
    kept_seqs = [s for s in seqs if len(NTDBGPT2_tokenizer.encode(s)) <= 19]
    if not kept_seqs:
        raise ValueError(f"no {ecode1}/{ecode2} sequence is within the 19-token prompt limit")

    all_ems = get_embeddings(kept_seqs, use_last=False)
    if len(all_ems) != len(kept_seqs):
        raise RuntimeError(
            f"get_embeddings returned {len(all_ems)} embeddings for {len(kept_seqs)} sequences")

    # Token counts differ between sequences, so the data is ragged. Build the object
    # array explicitly: NumPy >= 1.24 refuses to infer ragged arrays inside np.save.
    all_arr = np.empty(len(all_ems), dtype=object)
    for i, em in enumerate(all_ems):
        all_arr[i] = em.detach().cpu().numpy()
    end_arr = np.stack([em[-1] for em in all_arr])

    os.makedirs("./outputs", exist_ok=True)
    np.save(f"{prefix}_adv_all_em.npy", all_arr, allow_pickle=True)
    np.save(f"{prefix}_adv_end_em.npy", end_arr)
    np.save(f"{prefix}_adv_kept_seq.npy", np.array(kept_seqs))

    # The per-token embeddings are large; release them before the next set.
    del all_ems, all_arr, end_arr
    gc.collect()

In [42]:
# (stem E-code, procedure E-code, adversarial sequences), processed in this order.
for _ecode1, _ecode2, _adv_seqs in (
    ('e8120', 'e9654', e8120_e9654_adv_seq),
    ('e8120', 'e9240', e8120_e9240_adv_seq),
    ('e9654', 'e8120', e9654_e8120_adv_seq),
    ('e9654', 'e9240', e9654_e9240_adv_seq),
    ('e9240', 'e8120', e9240_e8120_adv_seq),
    ('e9240', 'e9654', e9240_e9654_adv_seq),
):
    create_adversarial_embedding_data(_ecode1, _ecode2, _adv_seqs)

100%|███████████████████████████████████████████████████████████| 5000/5000 [10:36<00:00,  7.85it/s]
  arr = np.asanyarray(arr)
  arr = np.asanyarray(arr)
100%|███████████████████████████████████████████████████████████| 5000/5000 [12:58<00:00,  6.42it/s]
100%|███████████████████████████████████████████████████████████| 5000/5000 [12:22<00:00,  6.73it/s]
100%|███████████████████████████████████████████████████████████| 5000/5000 [14:20<00:00,  5.81it/s]
100%|███████████████████████████████████████████████████████████| 5000/5000 [12:16<00:00,  6.79it/s]
100%|███████████████████████████████████████████████████████████| 5000/5000 [11:35<00:00,  7.19it/s]


In [43]:
# NOTE: these are the full, unfiltered sequence lists. get_embeddings skips prompts
# longer than 19 tokens, so the embedding files may have fewer rows than these.
for _pair, _adv_seqs in (
    ('e8120_e9654', e8120_e9654_adv_seq),
    ('e8120_e9240', e8120_e9240_adv_seq),
    ('e9654_e8120', e9654_e8120_adv_seq),
    ('e9654_e9240', e9654_e9240_adv_seq),
    ('e9240_e8120', e9240_e8120_adv_seq),
    ('e9240_e9654', e9240_e9654_adv_seq),
):
    np.save(f'outputs/{_pair}_adv_seq.npy', _adv_seqs)