In [1]:
BASE_PATH = '/home/wish/'

In [2]:
import sys
sys.path.insert(0, BASE_PATH + "projects/MedCAT/")

%load_ext autoreload
%autoreload 2

In [3]:
import json
import pandas as pd
import numpy as np
import os

import datasets
from transformers import AutoTokenizer, AutoModelForTokenClassification, Trainer, TrainingArguments

from medcat.datasets import medcat_ner
from medcat.datasets.tokenizer_ner import TokenizerNER
from medcat.datasets.data_collator import CollateAndPadNER

from medcat.cdb import CDB
from medcat.vocab import Vocab
from medcat.cat import CAT

In [41]:
DATA_PATH = BASE_PATH + "data/medmentions/medmentions.json"

In [42]:
data = json.load(open(DATA_PATH))

In [46]:
cnts = {}
for p in data['projects']:
    for d in p['documents']:
        for a in d['annotations']:
            cnts[a['cui']] = cnts.get(a['cui'], 0) + 1

In [47]:
cnts

{'C0030705': 5897,
 'C2603343': 2140,
 'C0441889': 1041,
 'C0936012': 1192,
 'C1511726': 1001,
 'C1274040': 1044,
 'C0332281': 1300,
 'C0243095': 1925,
 'C0441833': 1203,
 'C0017337': 1032,
 'C0087111': 1662,
 'C0392762': 1236,
 'C0205556': 1248}

In [16]:
# Get the CDB (If you do not have the mimic CDB, please use the MedMentions available in the medcat repository)
#cdb = CDB.load(BASE_PATH + "data/models/cdb_mimic_md_21-April-2021.dat")
cdb = CDB.load(BASE_PATH + "data/medcat_paper/cdb_pubmed_unsupervised.dat")

In [17]:
# Get the vocab
vocab = Vocab.load(BASE_PATH + "data/vocabs/vocab.dat")

In [18]:
cdb.config.ner['upper_case_limit_len'] = 2
cat = CAT(cdb=cdb, config=cdb.config, vocab=vocab)

### Subset to frequent concepts - for testing

In [44]:
MIN_FREQ = 1000
for p in data['projects']:
    for d in p['documents']:
        anns = []
        for a in d['annotations']:
            if cnts[a['cui']] > MIN_FREQ:
                anns.append(a)
        d['annotations'] = anns

In [48]:
DATA_PATH = BASE_PATH + "data/medmentions/medmentions_only_above_300.json"
json.dump(data, open(DATA_PATH, 'w'))

In [49]:
dataset = datasets.load_dataset(os.path.abspath(medcat_ner.__file__), 
                                data_files=DATA_PATH, 
                                split=datasets.Split.TRAIN,
                                cache_dir='/tmp/')

Using custom data configuration default-825e1537563bb329


Downloading and preparing dataset med_catner/default (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /tmp/med_catner/default-825e1537563bb329/0.0.0/59a4c7ba592923c039fe649a9f8acd4b10c0ef79e1a1551554b7744ce009ad57...


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Dataset med_catner downloaded and prepared to /tmp/med_catner/default-825e1537563bb329/0.0.0/59a4c7ba592923c039fe649a9f8acd4b10c0ef79e1a1551554b7744ce009ad57. Subsequent calls will reuse this data.


In [50]:
dataset

Dataset({
    features: ['id', 'text', 'ent_starts', 'ent_ends', 'ent_cuis'],
    num_rows: 4392
})

In [51]:
hf_tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
id2type = {}
for i in range(hf_tokenizer.vocab_size):
    id2type[i] = 'sub' if hf_tokenizer.convert_ids_to_tokens(i).startswith("##") else 'start'
tokenizer = TokenizerNER(hf_tokenizer, id2type=id2type)

In [52]:
encoded_dataset = dataset.map(
        lambda examples: tokenizer.encode(examples, ignore_subwords=True),
        batched=True,
        remove_columns=['ent_cuis', 'ent_ends', 'ent_starts', 'text'])

HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))




In [53]:
encoded_dataset

Dataset({
    features: ['id', 'input_ids', 'labels'],
    num_rows: 4935
})

In [54]:
model = AutoModelForTokenClassification.from_pretrained("emilyalsentzer/Bio_ClinicalBERT", num_labels=len(tokenizer.label_map))

Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint 

In [55]:
encoded_dataset = encoded_dataset.train_test_split(test_size = 0.2)

In [56]:
collate_fn = CollateAndPadNER(hf_tokenizer.pad_token_id)

In [57]:
training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=10,              # total number of training epochs
    per_device_train_batch_size=4,  # batch size per device during training
    per_device_eval_batch_size=4,   # batch size for evaluation
    weight_decay=0.1431478776404838,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    logging_steps=200,
    eval_steps=500,
    warmup_steps=15,
    learning_rate= 4.4670352057797207e-05,
    eval_accumulation_steps=1,
    do_eval=True,
    evaluation_strategy='steps',
    load_best_model_at_end=True,
)

In [58]:
from sklearn.metrics import classification_report
def metrics(p):
    preds = np.argmax(p.predictions, axis=2)
    # Ignore predictions where label == -100, padding
    preds[np.where(p.label_ids == -100)] = -100
    print(classification_report(np.reshape(p.label_ids, -1), np.reshape(preds, -1)))
    return {'none': 0}

In [59]:
trainer = Trainer(
    model=model,                         
    args=training_args,                 
    train_dataset=encoded_dataset['train'],       
    eval_dataset=encoded_dataset['test'],     
    compute_metrics=metrics,
    data_collator=collate_fn,
    tokenizer=None
)

In [60]:
trainer.train()

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mwish[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.33 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Step,Training Loss,Validation Loss,None,Runtime,Samples Per Second
500,0.0745,0.058395,0,7.172,137.618
1000,0.0522,0.053738,0,7.1855,137.359
1500,0.0457,0.054341,0,7.1565,137.917
2000,0.0432,0.049285,0,7.1961,137.158
2500,0.036,0.05229,0,7.1753,137.555
3000,0.0332,0.059053,0,7.1992,137.098
3500,0.024,0.053852,0,7.1972,137.136
4000,0.025,0.060464,0,7.204,137.006
4500,0.017,0.060763,0,7.2144,136.81
5000,0.0156,0.065727,0,7.1956,137.166


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

        -100       1.00      1.00      1.00    177928
           0       0.99      0.99      0.99    248269
           1       1.00      1.00      1.00     73507
           2       0.86      0.88      0.87      1282
           3       0.60      0.37      0.46       463
           4       0.47      0.62      0.54       240
           5       0.44      0.74      0.55       282
           6       0.32      0.33      0.32       203
           7       0.20      0.00      0.00       790
           8       0.63      0.97      0.76       435
           9       0.33      0.65      0.44       248
          10       0.67      0.26      0.37       226
          11       0.48      0.40      0.44       380
          12       0.00      0.00      0.00       451
          13       0.00      0.00      0.00       395
          14       0.68      0.86      0.76       245

    accuracy                           0.99    505344
   macro avg       0.54   

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

        -100       1.00      1.00      1.00    177928
           0       0.99      0.99      0.99    248269
           1       1.00      1.00      1.00     73507
           2       0.86      0.91      0.88      1282
           3       0.51      0.81      0.63       463
           4       0.52      0.88      0.65       240
           5       0.54      0.65      0.59       282
           6       0.28      0.95      0.43       203
           7       0.50      0.10      0.16       790
           8       0.67      0.95      0.79       435
           9       0.42      0.48      0.45       248
          10       0.61      0.37      0.46       226
          11       0.40      0.63      0.49       380
          12       0.00      0.00      0.00       451
          13       0.00      0.00      0.00       395
          14       0.72      0.83      0.77       245

    accuracy                           0.99    505344
   macro avg       0.56   

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

        -100       1.00      1.00      1.00    177928
           0       0.99      1.00      0.99    248269
           1       1.00      1.00      1.00     73507
           2       0.89      0.89      0.89      1282
           3       0.58      0.10      0.17       463
           4       0.56      0.10      0.18       240
           5       0.63      0.38      0.47       282
           6       0.00      0.00      0.00       203
           7       0.54      0.08      0.14       790
           8       0.68      0.94      0.79       435
           9       0.00      0.00      0.00       248
          10       0.67      0.28      0.39       226
          11       0.69      0.25      0.37       380
          12       0.00      0.00      0.00       451
          13       0.00      0.00      0.00       395
          14       0.72      0.82      0.77       245

    accuracy                           0.99    505344
   macro avg       0.56   

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

        -100       1.00      1.00      1.00    177928
           0       0.99      1.00      0.99    248269
           1       1.00      1.00      1.00     73507
           2       0.88      0.90      0.89      1282
           3       0.57      0.65      0.61       463
           4       0.58      0.40      0.48       240
           5       0.64      0.51      0.56       282
           6       0.43      0.30      0.35       203
           7       0.45      0.08      0.13       790
           8       0.69      0.94      0.79       435
           9       0.47      0.10      0.17       248
          10       0.65      0.42      0.51       226
          11       0.57      0.53      0.55       380
          12       0.47      0.06      0.11       451
          13       0.00      0.00      0.00       395
          14       0.71      0.83      0.77       245

    accuracy                           0.99    505344
   macro avg       0.63   

  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

        -100       1.00      1.00      1.00    177928
           0       0.99      0.99      0.99    248269
           1       1.00      1.00      1.00     73507
           2       0.90      0.88      0.89      1282
           3       0.57      0.59      0.58       463
           4       0.52      0.20      0.28       240
           5       0.64      0.58      0.61       282
           6       0.31      0.05      0.09       203
           7       0.20      0.12      0.15       790
           8       0.67      0.94      0.78       435
           9       0.31      0.02      0.04       248
          10       0.41      0.58      0.48       226
          11       0.56      0.50      0.53       380
          12       0.43      0.08      0.13       451
          13       0.21      0.01      0.01       395
          14       0.71      0.84      0.77       245

    accuracy                           0.99    505344
   macro avg       0.59   

KeyboardInterrupt: 

In [61]:
p = trainer.predict(encoded_dataset['test'])

              precision    recall  f1-score   support

        -100       1.00      1.00      1.00    177928
           0       0.99      0.99      0.99    248269
           1       1.00      1.00      1.00     73507
           2       0.87      0.90      0.89      1282
           3       0.54      0.56      0.55       463
           4       0.58      0.65      0.62       240
           5       0.62      0.52      0.57       282
           6       0.37      0.32      0.34       203
           7       0.35      0.12      0.18       790
           8       0.68      0.87      0.76       435
           9       0.41      0.53      0.46       248
          10       0.53      0.59      0.56       226
          11       0.55      0.54      0.54       380
          12       0.31      0.10      0.15       451
          13       0.08      0.02      0.03       395
          14       0.73      0.77      0.75       245

    accuracy                           0.99    505344
   macro avg       0.60   

In [62]:
preds = np.argmax(p.predictions, axis=2)

In [63]:
# Ignore predictions where label == -100, padding
preds[np.where(p.label_ids == -100)] = -100
report = classification_report(np.reshape(p.label_ids, -1), np.reshape(preds, -1), output_dict=True)

In [64]:
r_label_map = {v:k for k,v in tokenizer.label_map.items()}
for key in report.keys():
    if key.isdigit():
        cui = r_label_map.get(int(key), key)
    else:
        cui = key
    
    name = cdb.get_name(cui)
    print(name)
    print(report[key])
    print()

-100
{'precision': 1.0, 'recall': 1.0, 'f1-score': 1.0, 'support': 177928}

O
{'precision': 0.9894611374426603, 'recall': 0.9930639749626413, 'f1-score': 0.9912592824892148, 'support': 248269}

X
{'precision': 0.9999591864499013, 'recall': 0.9999319792672806, 'f1-score': 0.9999455826735232, 'support': 73507}

Patients
{'precision': 0.8744326777609682, 'recall': 0.9017160686427457, 'f1-score': 0.8878648233486942, 'support': 1282}

Study
{'precision': 0.54375, 'recall': 0.5637149028077754, 'f1-score': 0.5535524920466596, 'support': 463}

Levels
{'precision': 0.5814814814814815, 'recall': 0.6541666666666667, 'f1-score': 0.615686274509804, 'support': 240}

Analysis
{'precision': 0.6244725738396625, 'recall': 0.524822695035461, 'f1-score': 0.5703275529865125, 'support': 282}

Result
{'precision': 0.3693181818181818, 'recall': 0.32019704433497537, 'f1-score': 0.34300791556728233, 'support': 203}

Finding
{'precision': 0.3467153284671533, 'recall': 0.12025316455696203, 'f1-score': 0.178571428

## Test MedCAT on the same data

In [72]:
data = json.load(open(DATA_PATH))

In [73]:
cnts = {}
for p in data['projects']:
    for d in p['documents']:
        for a in d['annotations']:
            cnts[a['cui']] = cnts.get(a['cui'], 0) + 1
cnts

{'C0030705': 5897,
 'C2603343': 2140,
 'C0441889': 1041,
 'C0936012': 1192,
 'C1511726': 1001,
 'C1274040': 1044,
 'C0332281': 1300,
 'C0243095': 1925,
 'C0441833': 1203,
 'C0017337': 1032,
 'C0087111': 1662,
 'C0392762': 1236,
 'C0205556': 1248}

In [74]:
cat.config.linking['filters']['cuis'] = set(cnts.keys())

In [68]:
# Print stats before training
fp, fn, tp, p, r, f1, cui_counts, examples = cat._print_stats(data)

HBox(children=(FloatProgress(value=0.0, description='Stats project', max=1.0, style=ProgressStyle(description_…

HBox(children=(FloatProgress(value=0.0, description='Stats document', max=4392.0, style=ProgressStyle(descript…

Epoch: 0, Prec: 0.5875118259224219, Rec: 0.1416514598540146, F1: 0.22826686270906082

Docs with false positives: unk

Docs with false negatives: unk



False Positives

Associated with                                                        - C0332281             -        639
Genes                                                                  - C0017337             -        319
Analysis                                                               - C0936012             -        319
Groups                                                                 - C0441833             -        141
Study                                                                  - C2603343             -        135
quantitative~concept                                                   - C0392762             -        135
Levels                                                                 - C0441889             -        107
Result                                                                 - C1274040 

In [75]:
np.average(list(f1.values()))

0.17901604244305488

In [76]:
np.average(list(p.values()))

  return array(a, dtype, copy=False, order=order, subok=True)


TypeError: can only concatenate str (not "int") to str

In [77]:
np.average(list(r.values()))

0.14785320773924018

In [None]:
fp, fn, tp, p, r, f1, cui_counts, examples = cat.train_supervised(data_path=DATA_PATH, print_stats=1, nepochs=5, test_size=0.2, devalue_others=True,
                         train_from_false_positives=True)

HBox(children=(FloatProgress(value=0.0, description='Stats project', max=1.0, style=ProgressStyle(description_…

HBox(children=(FloatProgress(value=0.0, description='Stats document', max=904.0, style=ProgressStyle(descripti…

Epoch: 0, Prec: 0.5780952380952381, Rec: 0.13839489284085726, F1: 0.22332597498160411

Docs with false positives: unk

Docs with false negatives: unk



False Positives

Associated with                                                        - C0332281             -        128
Analysis                                                               - C0936012             -         67
Genes                                                                  - C0017337             -         67
quantitative~concept                                                   - C0392762             -         30
Study                                                                  - C2603343             -         30
Groups                                                                 - C0441833             -         22
Levels                                                                 - C0441889             -         21
Data                                                                   - C1511726

HBox(children=(FloatProgress(value=0.0, description='Epoch', max=5.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Project', max=1.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Document', max=3488.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Stats project', max=1.0, style=ProgressStyle(description_…

HBox(children=(FloatProgress(value=0.0, description='Stats document', max=904.0, style=ProgressStyle(descripti…

Epoch: 1, Prec: 0.14211269134730214, Rec: 0.29844961240310075, F1: 0.19254247260425095

Docs with false positives: unk

Docs with false negatives: unk



False Positives

sophisticated                                                          - C0205556             -       2909
sos                                                                    - C0392762             -       1664
Therapeutic procedure                                                  - C0087111             -       1055
Finding                                                                - C0243095             -        905
Patients                                                               - C0030705             -        779
Genes                                                                  - C0017337             -        197
Analysis                                                               - C0936012             -         73
Study                                                                  - C260334

HBox(children=(FloatProgress(value=0.0, description='Project', max=1.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Document', max=3488.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Stats project', max=1.0, style=ProgressStyle(description_…

HBox(children=(FloatProgress(value=0.0, description='Stats document', max=904.0, style=ProgressStyle(descripti…

Epoch: 2, Prec: 0.14255409418752651, Rec: 0.3831242873432155, F1: 0.20779220779220783

Docs with false positives: unk

Docs with false negatives: unk



False Positives

sophisticated                                                          - C0205556             -       3287
sos                                                                    - C0392762             -       1677
Finding                                                                - C0243095             -       1606
Patients                                                               - C0030705             -       1437
Therapeutic procedure                                                  - C0087111             -       1052
Genes                                                                  - C0017337             -        302
Study                                                                  - C2603343             -        161
Result                                                                 - C1274040

HBox(children=(FloatProgress(value=0.0, description='Project', max=1.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Document', max=3488.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=0.0, description='Stats project', max=1.0, style=ProgressStyle(description_…

HBox(children=(FloatProgress(value=0.0, description='Stats document', max=904.0, style=ProgressStyle(descripti…

Epoch: 3, Prec: 0.14423221537073244, Rec: 0.43740022805017104, F1: 0.21693151614545045

Docs with false positives: unk

Docs with false negatives: unk



False Positives

sophisticated                                                          - C0205556             -       3417
Finding                                                                - C0243095             -       1932
sos                                                                    - C0392762             -       1728
Patients                                                               - C0030705             -       1662
Therapeutic procedure                                                  - C0087111             -       1049
Genes                                                                  - C0017337             -        386
Result                                                                 - C1274040             -        281
Study                                                                  - C260334

HBox(children=(FloatProgress(value=0.0, description='Project', max=1.0, style=ProgressStyle(description_width=…

HBox(children=(FloatProgress(value=0.0, description='Document', max=3488.0, style=ProgressStyle(description_wi…

In [22]:
np.average(list(f1.values()))

0.43076030010032723

In [None]:
np.average(list(p.values()))

In [None]:
np.average(list(r.values()))

In [20]:
cat.config.linking['similarity_threshold'] = 0.2