In [2]:
import os

from bert_deid.create_csv import split_by_overlap
from bert_deid import model, utils
from pytorch_pretrained_bert.modeling import WEIGHTS_NAME, CONFIG_NAME

import pandas as pd

# Load in model

In [6]:
# Load a trained model and config that you have fine-tuned
max_seq_length = 100
step_size = 40

bert_model = model.BertForDEID(
    model_dir="/db/git/bert-deid/models/physionet_goldstandard",
    max_seq_length=max_seq_length,
    token_step_size=step_size
)
bert_model.to('cpu')
bert_model.eval()

Loading model and configuration from /db/git/bert-deid/models/physionet_goldstandard.


BertForDEID(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): BertLayerNorm()
              (dropout): Dropout(p=0.1)
            )
          )
          (intermediate): BertIntermediate(
           

# Split text into segments and tokenize

In [7]:
text = """
O: 58 YEAR OLD FEMALE ADMITTED IN TRANSFER FROM CALVERT HOSPITAL FOR MENTAL STATUS CHANGES POST FALL AT HOME AND CONTINUED HYPOTENSION AT CALVERT HOSPITAL REQUIRING DOPAMINE; PMH: CAD, S/P MI 1992; LCX PTCA; 3V CABG WITH MVR; CMP; AFIB- AV NODE ABLATION; PERM PACER- DDD MODE; PULM HTN; PVD; NIDDM; HPI: 2 WEEK HISTORY LEG WEAKNESS; 7/22 FOUND BY HUSBAND ON FLOOR- AWAKE, BUT MENTAL STATUS CHANGES; TO CALVERT HOSPITAL ER- TO THEIR ICU; HEAD CT- NEG FOR BLEED; VQ SCAN- NEG FOR PE; ECHO- GLOBAL HYPOKINESIS; EF EST 20%; R/O FOR MI; DIGOXIN TOXIC WITH HYPERKALEMIA- KAYEXALATE, DEXTROSE, INSULIN; RENAL INSUFFICIENCY- BUN 54, CR 2.8; INR 7 ( ON COUMADIN AT HOME); 7/23 AT CALVERT- 2 FFP, 2 UNITS PRBC, VITAMIN K; REFERRED TO GH. 
 ARRIVED IN TRANSFER APPROX. 2130; IN NO MAJOR DISTRESS; DOPAMINE TAPER, THEN DC; NS FLUID BOLUS GIVEN WITH IMPROVEMENT IN BP RANGE; SEE FLOW SHEET SECTION FOR CLINICAL INFORMATION; A: NO HEMODYNAMIC COMPROMISE SINCE TRANSFER; TOLERATING DOPAMINE DC; P: TREND BP RANGE; OBSERVE FOR PRECIPITOUS HYPOTENSION.
"""

# create splits
examples = split_by_overlap(
    text, bert_model.tokenizer,
    token_step_size=bert_model.token_step_size,
    max_seq_len=bert_model.max_seq_length
)

for e, example in enumerate(examples):
    # track offsets in tokenization
    tokens, tokens_sw, tokens_idx = bert_model.tokenizer.tokenize_with_index(
        example[3])
    print("Sentence {}: {}".format(e+1, " ".join([str(x) for x in tokens])))

Sentence 1: O : 58 Y ##EA ##R O ##LD F ##EM ##AL ##E AD ##MI ##TT ##ED IN T ##RA ##NS ##F ##ER F ##RO ##M CA ##L ##VE ##RT H ##OS ##PI ##TA ##L F ##OR ME ##NT ##AL ST ##AT ##US CH ##AN ##GE ##S P ##OS ##T FA ##LL AT H ##OM ##E AND CO ##NT ##IN ##UE ##D H ##YP ##OT ##EN ##SI ##ON AT CA ##L ##VE ##RT H ##OS ##PI ##TA ##L R ##E ##Q ##UI ##RI ##NG D ##OP ##AM ##IN ##E ; PM ##H : CA ##D , S / P MI 1992
Sentence 2: ST ##AT ##US CH ##AN ##GE ##S P ##OS ##T FA ##LL AT H ##OM ##E AND CO ##NT ##IN ##UE ##D H ##YP ##OT ##EN ##SI ##ON AT CA ##L ##VE ##RT H ##OS ##PI ##TA ##L R ##E ##Q ##UI ##RI ##NG D ##OP ##AM ##IN ##E ; PM ##H : CA ##D , S / P MI 1992 ; L ##C ##X PT ##CA ; 3 ##V CA ##B ##G W ##IT ##H MV ##R ; C ##MP ; A ##FI ##B - A ##V NO ##DE AB ##LA ##TI ##ON ; P ##ER ##M PA ##CE ##R
Sentence 3: R ##E ##Q ##UI ##RI ##NG D ##OP ##AM ##IN ##E ; PM ##H : CA ##D , S / P MI 1992 ; L ##C ##X PT ##CA ; 3 ##V CA ##B ##G W ##IT ##H MV ##R ; C ##MP ; A ##FI ##B - A ##V NO ##DE AB ##LA ##TI ##ON ; P ##E

# Annotate text

In [8]:
df = bert_model.annotate(text, document_id='1-1')
display(df)

Unnamed: 0,document_id,annotation_id,annotator,start,stop,entity,entity_type,comment,confidence
0,1-1,bert.0.10,bert-base-cased,49,56,CALVERT,LOCATION,,7.490695
1,1-1,bert.0.24,bert-base-cased,139,146,CALVERT,LOCATION,,7.575217
2,1-1,bert.0.37,bert-base-cased,193,197,1992,DATE,,6.841505
3,1-1,bert.1.10,bert-base-cased,139,146,CALVERT,LOCATION,,7.450922
4,1-1,bert.1.23,bert-base-cased,193,197,1992,DATE,,7.286679
5,1-1,bert.2.11,bert-base-cased,193,197,1992,DATE,,7.097697
6,1-1,bert.3.28,bert-base-cased,334,335,7,DATE,,8.904808
7,1-1,bert.3.29,bert-base-cased,335,336,/,DATE,,8.941896
8,1-1,bert.3.30,bert-base-cased,336,338,22,DATE,,9.017146
9,1-1,bert.4.9,bert-base-cased,334,335,7,DATE,,8.854787


# Pool model annotations for overlapping segments

In [10]:
if max_seq_length > step_size:
    df = bert_model.pool_annotations(df)
    display(df)
else:
    print('Non-overlapping segments - no pooling performed.')

Unnamed: 0,document_id,annotation_id,annotator,start,stop,entity,entity_type,comment,confidence
0,1-1,bert.0.10,bert-base-cased,49,56,CALVERT,LOCATION,,7.490695
1,1-1,bert.0.24,bert-base-cased,139,146,CALVERT,LOCATION,,7.575217
2,1-1,bert.1.23,bert-base-cased,193,197,1992,DATE,,7.286679
3,1-1,bert.3.28,bert-base-cased,334,335,7,DATE,,8.904808
4,1-1,bert.3.29,bert-base-cased,335,336,/,DATE,,8.941896
5,1-1,bert.3.30,bert-base-cased,336,338,22,DATE,,9.017146
6,1-1,bert.4.26,bert-base-cased,403,410,CALVERT,LOCATION,,7.847585
7,1-1,bert.8.22,bert-base-cased,664,665,7,DATE,,8.921864
8,1-1,bert.8.23,bert-base-cased,665,666,/,DATE,,9.038324
9,1-1,bert.8.24,bert-base-cased,666,668,23,DATE,,8.955135


# Harmonize entity types and merge nearby annotations

In [11]:
# merges entity types + combines annotations <= 1 character apart
df = utils.simplify_bert_ann(df, text, lowercase=True, dist=1)
display(df)

Unnamed: 0,document_id,annotation_id,annotator,start,stop,entity,entity_type,comment,confidence
2,1-1,bert.1.23,bert-base-cased,193,197,1992,date,,7.286679
3,1-1,bert.3.28,bert-base-cased,334,338,7/22,date,,8.904808
7,1-1,bert.8.22,bert-base-cased,664,668,7/23,date,,8.921864
0,1-1,bert.0.10,bert-base-cased,49,56,CALVERT,location,,7.490695
1,1-1,bert.0.24,bert-base-cased,139,146,CALVERT,location,,7.575217
6,1-1,bert.4.26,bert-base-cased,403,410,CALVERT,location,,7.847585
10,1-1,bert.8.26,bert-base-cased,672,679,CALVERT,location,,7.875831
11,1-1,bert.8.40,bert-base-cased,725,727,GH,location,,8.485923


# (Optional) Load ground truth and evaluate performance

In [13]:
gs_fn = '/db/git/deid-gs/physionet_goldstandard/train/ann/1-1.gs'
gs = pd.read_csv(gs_fn, header=0,
                 dtype={'entity': str,
                        'entity_type': str})

# fix entities - lower case and group
gs = utils.combine_entity_types(gs, lowercase=True)

# run comparison looking for exact/partial/misses
cmp_ann = utils.compare_single_doc(gs, df)
# add in the text/start/stop from gold standard annot
cmp_ann = cmp_ann.merge(gs[['annotation_id',
                            'start', 'stop',
                            'entity_type', 'entity']],
                        how='left', on='annotation_id')
display(cmp_ann)

Unnamed: 0,document_id,annotation_id,exact,partial,missed,span,start,stop,entity_type,entity
0,1-1,1,0,1,0,48 49,48,55,location,CALVERT
1,1-1,2,0,1,0,138 139,138,145,location,CALVERT
2,1-1,3,0,1,0,192 193,192,196,dateyear,1992
3,1-1,4,0,1,0,333 334,333,337,date,7/22
4,1-1,5,0,1,0,402 403,402,409,location,CALVERT
5,1-1,6,0,1,0,663 664,663,667,date,7/23
6,1-1,7,0,1,0,671 672,671,678,location,CALVERT
7,1-1,8,0,1,0,724 725,724,726,location,GH
