# Comparing different embeddings for biomedical sequence labelling

In [2]:
# Editable install of the local vadim-ml helper package (provides the vadim_ml.* modules imported below).
!pip install -e vadim-ml-tools

Obtaining file:///notebook/code/biomed_ie/src/vadim-ml-tools
Installing collected packages: vadim-ml
  Found existing installation: vadim-ml 0.1.0
    Uninstalling vadim-ml-0.1.0:
      Successfully uninstalled vadim-ml-0.1.0
  Running setup.py develop for vadim-ml
Successfully installed vadim-ml
[33mYou are using pip version 18.1, however version 19.1.1 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.[0m


In [3]:
# NOTE(review): versions are unpinned; pin them (notably the flair release these results were produced with)
# so that a fresh environment reproduces the training logs and metrics below.
!pip install flair tinydb hyperopt nltk allennlp

Collecting s3transfer<0.3.0,>=0.2.0 (from awscli>=1.11.91->allennlp)
[?25l  Downloading https://files.pythonhosted.org/packages/d7/de/5737f602e22073ecbded7a0c590707085e154e32b68d86545dcc31004c02/s3transfer-0.2.0-py2.py3-none-any.whl (69kB)
[K    100% |████████████████████████████████| 71kB 373kB/s ta 0:00:01
Collecting botocore==1.12.151 (from awscli>=1.11.91->allennlp)


[?25l  Downloading https://files.pythonhosted.org/packages/57/89/39e9d8a45ff3290c41d47065ec0abc9936925a4bc88bb488e07897d9f38d/botocore-1.12.151-py2.py3-none-any.whl (5.4MB)
[K    100% |████████████████████████████████| 5.4MB 5.4MB/s eta 0:00:01    66% |█████████████████████▍          | 3.6MB 3.8MB/s eta 0:00:01


[31mmoto 1.3.8 has requirement boto3>=1.9.86, but you'll have boto3 1.9.59 which is incompatible.[0m
[31mboto3 1.9.59 has requirement s3transfer<0.2.0,>=0.1.10, but you'll have s3transfer 0.2.0 which is incompatible.[0m
Installing collected packages: botocore, s3transfer
  Found existing installation: botocore 1.12.59
    Uninstalling botocore-1.12.59:
      Successfully uninstalled botocore-1.12.59
  Found existing installation: s3transfer 0.1.13
    Uninstalling s3transfer-0.1.13:
      Successfully uninstalled s3transfer-0.1.13
Successfully installed botocore-1.12.151 s3transfer-0.2.0
[33mYou are using pip version 18.1, however version 19.1.1 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.[0m


In [1]:
import pandas as pd

In [2]:
dataset_path = '../data/all_attributes.json'
dataset = pd.read_json(dataset_path)

# Fail fast if the export lacks a column that tokenize_row() reads.
expected_columns = {'texts', 'HYPERTENSION', 'CAD', 'DIABETES'}
missing_columns = expected_columns - set(dataset.columns)
assert not missing_columns, f'{dataset_path} is missing columns: {sorted(missing_columns)}'

print(dataset.shape)
dataset.head()

(37038, 5)


Unnamed: 0,HYPERTENSION,CAD,DIABETES,texts,doc_ids
0,[],[],[],Record date: 2074-12-05\n\n \n \n \n \n \n \n ...,0
1,[],[],[],"228 Caldwell Road\nColorado City, NY 43414\n...",0
10,[],[],[],", simvastatin 10 mg po q.d.,\namlodipine 5 mg ...",0
100,[],[],[],SMOKED UNTIL 8/2/81.,1
1000,[],[],[],She had no respiratory symptoms of dyspnea or ...,21


In [3]:
import tinydb
# TinyDB store for experiment results.
# NOTE(review): `results_db` is not referenced by any later cell shown here; either log the scores into it or remove it.
results_db = tinydb.TinyDB('../results.json')

## Preprocessing

In [4]:
from vadim_ml.nlp import char_annotations_as_token_annotations, token_span_to_bio
from vadim_ml.memoize import disk_memoize
from nltk.tokenize import TreebankWordTokenizer

tok = TreebankWordTokenizer()

def tokenize_row(row):
    """Tokenize one dataset row and turn its annotations into per-token tags.

    Args:
        row: a row of ``dataset`` with a ``'texts'`` string plus one annotation
            field for each of 'HYPERTENSION', 'CAD' and 'DIABETES'.

    Returns:
        dict with ``'text'`` -> list of token strings, and each pathology name
        -> the tag sequence produced by ``token_span_to_bio``.
    """
    raw_text = row['texts']
    spans = list(tok.span_tokenize(raw_text))
    words = [raw_text[start:end] for start, end in spans]

    result = {'text': words}
    for pathology in ('HYPERTENSION', 'CAD', 'DIABETES'):
        # Character-offset annotations are first mapped onto the token spans,
        # then rendered as one tag per token.
        token_level = char_annotations_as_token_annotations(spans, row[pathology])
        result[pathology] = token_span_to_bio(words, token_level)
    return result

In [5]:
from flair.data_fetcher import NLPTaskDataFetcher
import os

def produce_column_corpus(sentence_dicts, folds, column_fields, path):
    """Write sentences into flair-style column files, one file per fold.

    Every sentence is written as one line per token (the requested columns
    joined by tabs), followed by a blank line. Sentences are routed to
    ``<path>/<fold>.txt`` by pairing ``sentence_dicts`` with ``folds``;
    iteration stops as soon as ``sentence_dicts`` is exhausted, so ``folds``
    may be an infinite iterator.

    Args:
        sentence_dicts: iterable of dicts mapping column name -> per-token list
            of strings.
        folds: iterable of fold names (e.g. 'train', 'dev', 'test'), one per
            sentence.
        column_fields: column names to write, in order. Any iterable is
            accepted; it is materialized once up front.
        path: output directory, created if missing. Existing fold files are
            overwritten.

    If writing fails or is interrupted, the partially written fold files are
    deleted so that a truncated corpus is never mistaken for a complete one.
    """
    os.makedirs(path, exist_ok=True)
    # Materialize once so a one-shot iterator is not exhausted after the first sentence.
    column_fields = list(column_fields)

    files = {}
    completed = False
    try:
        for text_dict, fold in zip(sentence_dicts, folds):
            out = files.get(fold)
            if out is None:
                # Explicit UTF-8: clinical text may contain non-ASCII characters
                # and the platform default encoding is not guaranteed to be UTF-8.
                out = open(os.path.join(path, fold + '.txt'), 'w', encoding='utf-8')
                files[fold] = out

            for line in zip(*(text_dict[c] for c in column_fields)):
                out.write('\t'.join(line) + '\n')
            out.write('\n')
        completed = True
    finally:
        # Always release the handles, even on error / KeyboardInterrupt.
        for out in files.values():
            out.close()
        if not completed:
            # Don't leave a truncated corpus behind: column_corpus() would load it as-is.
            for out in files.values():
                try:
                    os.remove(out.name)
                except OSError:
                    pass

In [6]:
import random

def train_dev_test_distr(test_frac=0.2, dev_frac=0.2, rng=random):
    """Randomly assign one example to a data split.

    Draws ``u`` uniformly from [0, 1) and returns 'test' if ``u < test_frac``,
    'dev' if ``u < test_frac + dev_frac``, otherwise 'train'. With the default
    arguments this is a 20/20/60 test/dev/train split.

    Args:
        test_frac: probability of returning 'test'.
        dev_frac: probability of returning 'dev'.
        rng: any object with a ``random()`` method. Pass ``random.Random(seed)``
            for a reproducible split; the default (the global ``random``
            module, unseeded) gives a different split on each fresh kernel.

    Raises:
        ValueError: if a fraction is negative or the two sum to more than 1.
    """
    if test_frac < 0 or dev_frac < 0 or test_frac + dev_frac > 1:
        raise ValueError(f'invalid split fractions: test={test_frac}, dev={dev_frac}')

    x = rng.random()
    if x < test_frac:
        return 'test'
    if x < test_frac + dev_frac:
        return 'dev'
    return 'train'

In [7]:
from tqdm import tqdm_notebook

corpus_path = '../data/columncorpus'
columns = {0: 'text', 1: 'HYPERTENSION', 2: 'CAD', 3: 'DIABETES'}

def column_corpus():
    """Load the train/dev/test column corpus, building it from ``dataset`` if needed.

    If any split file is missing under ``corpus_path``, every row of
    ``dataset`` is tokenized with ``tokenize_row``, assigned a split by
    ``train_dev_test_distr`` and written with ``produce_column_corpus``.
    The corpus is then loaded with flair using the ``columns`` mapping.
    """
    split_files = {'train_file': 'train.txt', 'test_file': 'test.txt', 'dev_file': 'dev.txt'}

    corpus_complete = all(os.path.exists(os.path.join(corpus_path, name))
                          for name in split_files.values())
    if not corpus_complete:
        rows = tqdm_notebook(dataset.iterrows(), total=len(dataset))
        texts = (tokenize_row(row) for _, row in rows)
        # One split label per dataset row; produce_column_corpus zips them with texts.
        folds = (train_dev_test_distr() for _ in range(len(dataset)))
        produce_column_corpus(texts, folds, columns.values(), corpus_path)

    return NLPTaskDataFetcher.load_column_corpus(corpus_path, columns, **split_files)

In [8]:
# Load the column corpus from ../data/columncorpus, generating it from `dataset` first if a split file is missing.
corpus = column_corpus()

2019-05-31 02:09:58,478 Reading data from ../data/columncorpus
2019-05-31 02:09:58,479 Train: ../data/columncorpus/train.txt
2019-05-31 02:09:58,480 Dev: ../data/columncorpus/dev.txt
2019-05-31 02:09:58,481 Test: ../data/columncorpus/test.txt


## Flair sequence tagger

GloVe word embeddings stacked with forward and backward Flair embeddings trained on PubMed abstracts, fed into a BiLSTM-CRF tagger (one tagger per pathology).

In [9]:
# The tag dictionary is the same for all 3 pathologies (every tag column is produced by token_span_to_bio),
# so it is built once from HYPERTENSION and reused by every tagger trained below.
tag_dictionary = corpus.make_tag_dictionary('HYPERTENSION')

In [13]:
# Imported here as well: this cell must not depend on the later training cell
# having run first (otherwise Restart & Run All fails with a NameError).
from flair.embeddings import WordEmbeddings, FlairEmbeddings, StackedEmbeddings

# GloVe word vectors + forward/backward Flair character LMs trained on PubMed.
embeddings = StackedEmbeddings([WordEmbeddings('glove'),
                                FlairEmbeddings('pubmed-forward'),
                                FlairEmbeddings('pubmed-backward'),
                               ])

2019-05-13 09:33:05,011 https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings-v0.4.1/pubmed-2015-bw-lm.pt not found in cache, downloading to /tmp/tmp5d92gn7x


100%|██████████| 111081366/111081366 [00:11<00:00, 10068748.57B/s]

2019-05-13 09:33:16,354 copying /tmp/tmp5d92gn7x to cache at /root/.flair/embeddings/pubmed-2015-bw-lm.pt





2019-05-13 09:33:16,982 removing temp file /tmp/tmp5d92gn7x


In [10]:
from flair.embeddings import WordEmbeddings, FlairEmbeddings, StackedEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer
from flair.training_utils import EvaluationMetric

def train_tagger(pathology, embeddings, folder, hidden_size=256, learning_rate=0.1,
                 mini_batch_size=32, max_epochs=150):
    """Train a BiLSTM-CRF SequenceTagger for one pathology tag column.

    Uses the module-level ``corpus`` and ``tag_dictionary``. Training output
    (checkpoints, logs and the test predictions ``test.tsv`` that
    ``detection_report`` reads) goes to ``../models/<folder>``.

    Args:
        pathology: tag column to learn ('HYPERTENSION', 'CAD' or 'DIABETES').
        embeddings: flair token embeddings fed to the tagger.
        folder: output sub-directory under ``../models``.
        hidden_size: LSTM hidden size of the tagger.
        learning_rate: initial SGD learning rate passed to ``ModelTrainer.train``.
        mini_batch_size: batch size passed to ``ModelTrainer.train``.
        max_epochs: epoch cap passed to ``ModelTrainer.train``.

    Returns:
        The trained tagger held by the trainer after training.
    """
    tagger = SequenceTagger(hidden_size=hidden_size,
                            embeddings=embeddings,
                            tag_dictionary=tag_dictionary,
                            tag_type=pathology,
                            use_crf=True)

    trainer = ModelTrainer(tagger, corpus)
    trainer.train(f'../models/{folder}',
                  EvaluationMetric.MICRO_F1_SCORE,
                  learning_rate=learning_rate,
                  mini_batch_size=mini_batch_size,
                  max_epochs=max_epochs,
                  checkpoint=True)

    # NOTE(review): flair's final test step may reload best-model.pt into
    # trainer.model, while the local `tagger` keeps the last-epoch weights.
    # Returning trainer.model hands back the model that produced test.tsv
    # (it is the same object as `tagger` if no reload happened). Verify
    # against the installed flair version.
    return trainer.model

In [None]:
taggers = {}

# Train one Flair/PubMed tagger per pathology; each gets its own ../models/flair_<pathology> folder.
for pathology in ('HYPERTENSION', 'CAD', 'DIABETES'):
    print(pathology)
    taggers[pathology] = train_tagger(pathology, embeddings,
                                      folder='flair_' + pathology.lower())

HYPERTENSION
2019-05-13 09:56:20,619 ----------------------------------------------------------------------------------------------------
2019-05-13 09:56:20,620 Evaluation method: MICRO_F1_SCORE
2019-05-13 09:56:20,622 ----------------------------------------------------------------------------------------------------
2019-05-13 09:56:21,381 epoch 1 - iter 0/693 - loss 19.90345383
2019-05-13 09:56:41,181 epoch 1 - iter 69/693 - loss 0.79330848
2019-05-13 09:57:02,946 epoch 1 - iter 138/693 - loss 0.63059303
2019-05-13 09:57:25,740 epoch 1 - iter 207/693 - loss 0.52976911
2019-05-13 09:57:48,526 epoch 1 - iter 276/693 - loss 0.47508759
2019-05-13 09:58:08,615 epoch 1 - iter 345/693 - loss 0.43311270
2019-05-13 09:58:32,522 epoch 1 - iter 414/693 - loss 0.41281813
2019-05-13 09:58:55,792 epoch 1 - iter 483/693 - loss 0.38964588
2019-05-13 09:59:20,066 epoch 1 - iter 552/693 - loss 0.37273970
2019-05-13 09:59:41,577 epoch 1 - iter 621/693 - loss 0.35172129
2019-05-13 10:00:04,208 epoch 1

2019-05-13 10:35:51,383 EPOCH 7 done: loss 0.1357 - lr 0.1000 - bad epochs 0
2019-05-13 10:36:50,135 DEV  : loss 0.11631372 - f-score 0.3905 - acc 0.2426
2019-05-13 10:37:52,571 TEST : loss 0.09813494 - f-score 0.4228 - acc 0.2681
2019-05-13 10:38:05,766 ----------------------------------------------------------------------------------------------------
2019-05-13 10:38:06,066 epoch 8 - iter 0/693 - loss 0.04144216
2019-05-13 10:38:28,059 epoch 8 - iter 69/693 - loss 0.10165724
2019-05-13 10:38:54,116 epoch 8 - iter 138/693 - loss 0.11263817
2019-05-13 10:39:20,243 epoch 8 - iter 207/693 - loss 0.11782995
2019-05-13 10:39:44,428 epoch 8 - iter 276/693 - loss 0.12288016
2019-05-13 10:40:06,782 epoch 8 - iter 345/693 - loss 0.13138471
2019-05-13 10:40:31,997 epoch 8 - iter 414/693 - loss 0.13184896
2019-05-13 10:40:53,367 epoch 8 - iter 483/693 - loss 0.13542160
2019-05-13 10:41:17,386 epoch 8 - iter 552/693 - loss 0.13681434
2019-05-13 10:41:38,869 epoch 8 - iter 621/693 - loss 0.136603

2019-05-13 11:17:26,449 epoch 14 - iter 690/693 - loss 0.11696168
2019-05-13 11:17:27,011 ----------------------------------------------------------------------------------------------------
2019-05-13 11:17:27,012 EPOCH 14 done: loss 0.1168 - lr 0.1000 - bad epochs 2
2019-05-13 11:18:24,274 DEV  : loss 0.10381896 - f-score 0.5406 - acc 0.3705
2019-05-13 11:19:25,340 TEST : loss 0.09096314 - f-score 0.5580 - acc 0.3869
2019-05-13 11:19:38,507 ----------------------------------------------------------------------------------------------------
2019-05-13 11:19:39,162 epoch 15 - iter 0/693 - loss 0.06764615
2019-05-13 11:20:02,565 epoch 15 - iter 69/693 - loss 0.15816344
2019-05-13 11:20:22,187 epoch 15 - iter 138/693 - loss 0.13901940
2019-05-13 11:20:43,253 epoch 15 - iter 207/693 - loss 0.12334205
2019-05-13 11:21:03,930 epoch 15 - iter 276/693 - loss 0.11509385
2019-05-13 11:21:23,057 epoch 15 - iter 345/693 - loss 0.11788456
2019-05-13 11:21:43,757 epoch 15 - iter 414/693 - loss 0.12

2019-05-13 11:57:00,615 epoch 21 - iter 483/693 - loss 0.10715379
2019-05-13 11:57:25,170 epoch 21 - iter 552/693 - loss 0.11048361
2019-05-13 11:57:49,367 epoch 21 - iter 621/693 - loss 0.11282889
2019-05-13 11:58:09,877 epoch 21 - iter 690/693 - loss 0.11466386
2019-05-13 11:58:10,546 ----------------------------------------------------------------------------------------------------
2019-05-13 11:58:10,547 EPOCH 21 done: loss 0.1144 - lr 0.1000 - bad epochs 3
2019-05-13 11:59:08,465 DEV  : loss 0.12368022 - f-score 0.5330 - acc 0.3634
2019-05-13 12:00:09,716 TEST : loss 0.10091684 - f-score 0.5683 - acc 0.3969
Epoch    20: reducing learning rate of group 0 to 5.0000e-02.
2019-05-13 12:00:16,014 ----------------------------------------------------------------------------------------------------
2019-05-13 12:00:16,495 epoch 22 - iter 0/693 - loss 0.00104624
2019-05-13 12:00:39,381 epoch 22 - iter 69/693 - loss 0.08153063
2019-05-13 12:01:06,673 epoch 22 - iter 138/693 - loss 0.083781

In [11]:
from vadim_ml.metrics import binary_classification_report
from vadim_ml.io import load_file

def detection_report(model_name, text=True):
    """Text-level detection metrics for a trained tagger, computed from its test.tsv.

    ``../models/<model_name>/test.tsv`` is split into blank-line separated
    sentences, one ``word gold_tag predicted_tag score`` line per token. A
    sentence counts as positive when it contains at least one begin tag
    ('B', or 'B-<type>'). The gold/predicted labels are scored with
    ``binary_classification_report``.

    Args:
        model_name: model folder under ``../models`` (e.g. 'flair_cad').
        text: forwarded unchanged to ``binary_classification_report``.

    Returns:
        Whatever ``binary_classification_report`` returns.
    """
    def is_begin(tag):
        return tag == 'B' or tag.startswith('B-')

    y_true = []
    y_pred = []

    for sentence in load_file(f'../models/{model_name}/test.tsv').split('\n\n'):
        token_lines = [line for line in sentence.split('\n') if line.strip()]
        # Blank chunks (e.g. produced by a trailing blank line) contain no
        # tokens; counting them would add spurious true negatives.
        if not token_lines:
            continue

        true_bio_tags = []
        pred_bio_tags = []
        for token in token_lines:
            # Only the last three fields are gold tag, predicted tag and score;
            # rsplit keeps a word containing a space from breaking the unpack.
            _word, true, pred, _prob = token.rsplit(' ', 3)
            true_bio_tags.append(true)
            pred_bio_tags.append(pred)

        y_true.append(any(is_begin(tag) for tag in true_bio_tags))
        y_pred.append(any(is_begin(tag) for tag in pred_bio_tags))

    return binary_classification_report(y_true, y_pred, text)

In [12]:
# Text-level detection metrics: Flair/PubMed BiLSTM-CRF, HYPERTENSION
print(detection_report(model_name='flair_hypertension'))

true negatives: 7161
false positives: 28
false negatives: 91
true positives: 149
kappa: 0.7065812615149436
precision: 0.8418079096045198
recall: 0.6208333333333333
f1: 0.7146282973621104



In [18]:
# Text-level detection metrics: Flair/PubMed BiLSTM-CRF, CAD
print(detection_report(model_name='flair_cad'))

true negatives: 7219
false positives: 40
false negatives: 73
true positives: 97
kappa: 0.624247635425623
precision: 0.708029197080292
recall: 0.5705882352941176
f1: 0.6319218241042346



In [14]:
# Text-level detection metrics: Flair/PubMed BiLSTM-CRF, DIABETES
print(detection_report(model_name='flair_diabetes'))

true negatives: 7179
false positives: 34
false negatives: 38
true positives: 178
kappa: 0.8267865446815904
precision: 0.839622641509434
recall: 0.8240740740740741
f1: 0.8317757009345794



## Testing out different model parameters and embeddings

In [10]:
from flair.hyperparameter.parameter import SEQUENCE_TAGGER_PARAMETERS

# Display the SequenceTagger parameter names known to flair's hyperparameter module (listed below).
SEQUENCE_TAGGER_PARAMETERS

['embeddings',
 'hidden_size',
 'rnn_layers',
 'use_crf',
 'use_rnn',
 'dropout',
 'locked_dropout',
 'word_dropout']

In [13]:
from flair.hyperparameter.param_selection import SearchSpace, Parameter
from flair.embeddings import BertEmbeddings, ELMoEmbeddings, WordEmbeddings, BytePairEmbeddings, FlairEmbeddings, StackedEmbeddings

# Zero-argument factories rather than instances, so each embedding model is
# only constructed when the training loop calls it.
embeddings_to_try = dict([
    ('fasttext', lambda: WordEmbeddings('en')),
    ('elmo-general', lambda: ELMoEmbeddings('original')),
    ('elmo-pubmed', lambda: ELMoEmbeddings('pubmed')),
    # Not enough memory for the BERT-based variants:
    # ('bert', lambda: BertEmbeddings('bert-base-cased')),
    # ('bert-elmo-pubmed', lambda: StackedEmbeddings([ BertEmbeddings('bert-base-cased'), ELMoEmbeddings('pubmed') ])),
    # ('bert-flair-pubmed', lambda: StackedEmbeddings([ BertEmbeddings('bert-base-cased'), FlairEmbeddings('pubmed-forward'), FlairEmbeddings('pubmed-backward')])),
])

In [None]:
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer
from flair.training_utils import EvaluationMetric

# One training run per (pathology, embedding) pair; outputs go to ../models/hyperopt/<embedding>_<pathology>.
for pathology in ('HYPERTENSION', 'CAD', 'DIABETES'):
    for emb_name, make_embeddings in embeddings_to_try.items():
        print(f'Training the model for {pathology} with {emb_name} embeddings')
        train_tagger(pathology,
                     make_embeddings(),
                     folder=f'hyperopt/{emb_name}_{pathology.lower()}')

Training the model for CAD with elmo-general embeddings
2019-05-22 10:42:33,641 ----------------------------------------------------------------------------------------------------
2019-05-22 10:42:33,642 Evaluation method: MICRO_F1_SCORE
2019-05-22 10:42:33,644 ----------------------------------------------------------------------------------------------------
2019-05-22 10:42:34,452 epoch 1 - iter 0/693 - loss 5.73918247
2019-05-22 10:43:33,042 epoch 1 - iter 69/693 - loss 0.82185352
2019-05-22 10:44:26,810 epoch 1 - iter 138/693 - loss 0.75585596
2019-05-22 10:45:23,071 epoch 1 - iter 207/693 - loss 0.77252042
2019-05-22 10:46:13,118 epoch 1 - iter 276/693 - loss 0.71987132
2019-05-22 10:47:06,179 epoch 1 - iter 345/693 - loss 0.69413630
2019-05-22 10:47:55,838 epoch 1 - iter 414/693 - loss 0.69111095
2019-05-22 10:48:42,006 epoch 1 - iter 483/693 - loss 0.64716698
2019-05-22 10:49:36,169 epoch 1 - iter 552/693 - loss 0.62729253
2019-05-22 10:50:18,803 epoch 1 - iter 621/693 - loss 

2019-05-22 11:27:54,955 ----------------------------------------------------------------------------------------------------
2019-05-22 11:27:54,957 EPOCH 7 done: loss 0.2796 - lr 0.1000 - bad epochs 0
2019-05-22 11:28:50,305 DEV  : loss 0.21456926 - f-score 0.3457 - acc 0.2090
2019-05-22 11:29:49,171 TEST : loss 0.18661231 - f-score 0.2729 - acc 0.1580
2019-05-22 11:29:51,654 ----------------------------------------------------------------------------------------------------
2019-05-22 11:29:51,823 epoch 8 - iter 0/693 - loss 0.46480909
2019-05-22 11:30:12,799 epoch 8 - iter 69/693 - loss 0.29046286
2019-05-22 11:30:37,217 epoch 8 - iter 138/693 - loss 0.26829081


In [None]:
# Stand-alone run of the DIABETES / elmo-general combination.
train_tagger(pathology='DIABETES',
             embeddings=embeddings_to_try['elmo-general'](),
             folder='hyperopt/elmo-general_diabetes')

2019-05-31 10:45:29,651 ----------------------------------------------------------------------------------------------------
2019-05-31 10:45:29,652 Evaluation method: MICRO_F1_SCORE
2019-05-31 10:45:29,654 ----------------------------------------------------------------------------------------------------
2019-05-31 10:45:30,105 epoch 1 - iter 0/693 - loss 33.67508698
2019-05-31 10:46:28,618 epoch 1 - iter 69/693 - loss 1.69672735
2019-05-31 10:47:28,012 epoch 1 - iter 138/693 - loss 1.03634923
2019-05-31 10:48:27,931 epoch 1 - iter 207/693 - loss 0.83066240
2019-05-31 10:49:14,354 epoch 1 - iter 276/693 - loss 0.77436809
2019-05-31 10:50:14,665 epoch 1 - iter 345/693 - loss 0.70632665
2019-05-31 10:51:08,388 epoch 1 - iter 414/693 - loss 0.68452465
2019-05-31 10:52:05,456 epoch 1 - iter 483/693 - loss 0.64053016
2019-05-31 10:52:51,041 epoch 1 - iter 552/693 - loss 0.64900328
2019-05-31 10:53:35,174 epoch 1 - iter 621/693 - loss 0.61229206
2019-05-31 10:54:17,447 epoch 1 - iter 690/6

2019-05-31 11:31:16,986 EPOCH 7 done: loss 0.2353 - lr 0.1000 - bad epochs 2
2019-05-31 11:32:19,887 DEV  : loss 0.19348748 - f-score 0.4019 - acc 0.2514
2019-05-31 11:33:25,908 TEST : loss 0.19318265 - f-score 0.4696 - acc 0.3069
2019-05-31 11:33:31,302 ----------------------------------------------------------------------------------------------------
2019-05-31 11:33:31,602 epoch 8 - iter 0/693 - loss 0.00890633
2019-05-31 11:33:55,116 epoch 8 - iter 69/693 - loss 0.19240644
2019-05-31 11:34:16,977 epoch 8 - iter 138/693 - loss 0.22215896
2019-05-31 11:34:35,068 epoch 8 - iter 207/693 - loss 0.25610739
2019-05-31 11:34:50,867 epoch 8 - iter 276/693 - loss 0.22425647
2019-05-31 11:35:08,592 epoch 8 - iter 345/693 - loss 0.23474525
2019-05-31 11:35:30,348 epoch 8 - iter 414/693 - loss 0.23975534
2019-05-31 11:35:51,454 epoch 8 - iter 483/693 - loss 0.25055043
2019-05-31 11:36:11,257 epoch 8 - iter 552/693 - loss 0.24791340
2019-05-31 11:36:28,661 epoch 8 - iter 621/693 - loss 0.253737

In [16]:
# Text-level detection metrics: fasttext embeddings, HYPERTENSION
print(detection_report(model_name='hyperopt/fasttext_hypertension'))

true negatives: 7123
false positives: 66
false negatives: 51
true positives: 189
kappa: 0.7554981708629223
precision: 0.7411764705882353
recall: 0.7875
f1: 0.7636363636363638



In [30]:
# Text-level detection metrics: fasttext embeddings, CAD
print(detection_report(model_name='hyperopt/fasttext_cad'))

true negatives: 7224
false positives: 35
false negatives: 72
true positives: 98
kappa: 0.6396251281300342
precision: 0.7368421052631579
recall: 0.5764705882352941
f1: 0.6468646864686468



In [19]:
# Text-level detection metrics: fasttext embeddings, DIABETES
print(detection_report(model_name='hyperopt/fasttext_diabetes'))

true negatives: 7184
false positives: 29
false negatives: 47
true positives: 169
kappa: 0.8111736514529767
precision: 0.8535353535353535
recall: 0.7824074074074074
f1: 0.8164251207729469



In [17]:
# Text-level detection metrics: elmo-general embeddings, HYPERTENSION
print(detection_report(model_name='hyperopt/elmo-general_hypertension'))

true negatives: 7144
false positives: 45
false negatives: 29
true positives: 211
kappa: 0.8456595124405373
precision: 0.82421875
recall: 0.8791666666666667
f1: 0.8508064516129031



In [20]:
# Text-level detection metrics: elmo-general embeddings, CAD
print(detection_report(model_name='hyperopt/elmo-general_cad'))

true negatives: 7232
false positives: 27
false negatives: 81
true positives: 89
kappa: 0.6152353622148663
precision: 0.7672413793103449
recall: 0.5235294117647059
f1: 0.6223776223776224



In [17]:
# Text-level detection metrics: elmo-general embeddings, DIABETES
print(detection_report(model_name='hyperopt/elmo-general_diabetes'))

true negatives: 7176
false positives: 37
false negatives: 45
true positives: 171
kappa: 0.8009248245120978
precision: 0.8221153846153846
recall: 0.7916666666666666
f1: 0.8066037735849055

