In [1]:
BASE_PATH = '/home/wish/'

In [2]:
import sys
sys.path.insert(0, BASE_PATH + "projects/medflux/")
sys.path.insert(0, BASE_PATH + "projects/MedCAT/")

%load_ext autoreload
%autoreload 2

In [3]:
import json
import pandas as pd
import numpy as np
import os

import datasets

from transformers import AutoTokenizer, AutoModelForTokenClassification, Trainer, TrainingArguments

from medcat.datasets import medcat_ner
from medcat.datasets.tokenizer_ner import TokenizerNER
from medcat.datasets.data_collator import CollateAndPadNER

In [4]:
DATA_PATH = BASE_PATH + "data/mimic/mimic.json"

In [5]:
data = json.load(open(DATA_PATH))

In [6]:
cnts = {}
for p in data['projects']:
    for d in p['documents']:
        for a in d['annotations']:
            cnts[a['cui']] = cnts.get(a['cui'], 0) + 1

### Subset to frequent concepts - for testing

In [7]:
MIN_FREQ = 50
for p in data['projects']:
    for d in p['documents']:
        anns = []
        for a in d['annotations']:
            if cnts[a['cui']] > MIN_FREQ:
                anns.append(a)
        d['annotations'] = anns

In [8]:
DATA_PATH = BASE_PATH + "data/mimic/mimic_only_above_50.json"
json.dump(data, open(DATA_PATH, 'w'))

In [9]:
dataset = datasets.load_dataset(os.path.abspath(medcat_ner.__file__), 
                                data_files=DATA_PATH, 
                                split=datasets.Split.TRAIN)

Using custom data configuration default-337cf20d3b0a77fe


Downloading and preparing dataset med_catner/default (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/wish/.cache/huggingface/datasets/med_catner/default-337cf20d3b0a77fe/0.0.0/98e55c8f8beecf808ac4eaeb3a37a07e036655af8378f121808fe338b5a86b4a...


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Dataset med_catner downloaded and prepared to /home/wish/.cache/huggingface/datasets/med_catner/default-337cf20d3b0a77fe/0.0.0/98e55c8f8beecf808ac4eaeb3a37a07e036655af8378f121808fe338b5a86b4a. Subsequent calls will reuse this data.


In [10]:
hf_tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

In [11]:
id2type = {}
for i in range(hf_tokenizer.vocab_size):
    id2type[i] = 'sub' if hf_tokenizer.convert_ids_to_tokens(i).startswith("##") else 'start'

In [12]:
tokenizer = TokenizerNER(hf_tokenizer, id2type=id2type)

In [13]:
encoded_dataset = dataset.map(
        lambda examples: tokenizer.encode(examples, use_subwords=True),
        batched=True,
        remove_columns=['ent_cuis', 'ent_ends', 'ent_starts', 'text'])

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




In [14]:
model = AutoModelForTokenClassification.from_pretrained("emilyalsentzer/Bio_ClinicalBERT", num_labels=len(tokenizer.label_map))

Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint 

In [15]:
encoded_dataset = encoded_dataset.train_test_split(test_size = 0.2)

In [16]:
collate_fn = CollateAndPadNER(hf_tokenizer.pad_token_id)

In [17]:
training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=10,              # total number of training epochs
    per_device_train_batch_size=4,  # batch size per device during training
    per_device_eval_batch_size=4,   # batch size for evaluation
    weight_decay=0.1431478776404838,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    logging_steps=200,
    eval_steps=100,
    warmup_steps=15,
    learning_rate= 4.4670352057797207e-05,
    eval_accumulation_steps=1,
    do_eval=True,
    evaluation_strategy='steps',
    load_best_model_at_end=True,
)

In [18]:
from sklearn.metrics import classification_report
def metrics(p):
    preds = np.argmax(p.predictions, axis=2)
    print(preds)
    print(classification_report(np.reshape(p.label_ids, -1), np.reshape(preds, -1)))
    return {'none': 0}

In [19]:
trainer = Trainer(
    model=model,                         # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=encoded_dataset['train'],         # training dataset
    eval_dataset=encoded_dataset['test'],             # evaluation dataset
    compute_metrics=metrics,
    data_collator=collate_fn,
    tokenizer=None
)

In [20]:
trainer.train()

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mwish[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.32 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Step,Training Loss,Validation Loss,None,Runtime,Samples Per Second
100,No log,0.020473,0,0.9184,127.396
200,0.140100,0.011691,0,0.9258,126.371
300,0.140100,0.005034,0,0.9303,125.762
400,0.005400,0.003786,0,0.9297,125.851
500,0.005400,0.00361,0,0.938,124.729
600,0.001700,0.003689,0,0.9444,123.883
700,0.001700,0.003948,0,0.9387,124.647
800,0.000800,0.003729,0,0.9449,123.817
900,0.000800,0.00371,0,0.9334,125.349
1000,0.000400,0.003598,0,0.9301,125.796


[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 1 0 ... 0 0 0]]
              precision    recall  f1-score   support

        -100       0.00      0.00      0.00     16394
           0       0.69      1.00      0.82     34257
           1       0.89      1.00      0.94      9127
           2       0.00      0.00      0.00        30
           3       0.00      0.00      0.00        18
           4       0.00      0.00      0.00        27
           5       0.00      0.00      0.00        19
           6       0.00      0.00      0.00        22
           7       0.00      0.00      0.00        10

    accuracy                           0.72     59904
   macro avg       0.18      0.22      0.20     59904
weighted avg       0.53      0.72      0.61     59904



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 1 0 ... 0 0 0]]
              precision    recall  f1-score   support

        -100       0.00      0.00      0.00     16394
           0       0.69      1.00      0.82     34257
           1       0.89      1.00      0.94      9127
           2       0.48      0.80      0.60        30
           3       0.14      0.94      0.24        18
           4       0.50      0.04      0.07        27
           5       1.00      0.11      0.19        19
           6       0.00      0.00      0.00        22
           7       0.00      0.00      0.00        10

    accuracy                           0.72     59904
   macro avg       0.41      0.43      0.32     59904
weighted avg       0.53      0.72      0.61     59904



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[[1 0 0 ... 0 0 1]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 1 0 ... 0 0 0]]
              precision    recall  f1-score   support

        -100       0.00      0.00      0.00     16394
           0       0.71      1.00      0.83     34257
           1       0.80      1.00      0.89      9127
           2       0.54      1.00      0.70        30
           3       0.24      1.00      0.38        18
           4       0.85      0.81      0.83        27
           5       0.61      0.89      0.72        19
           6       0.61      0.91      0.73        22
           7       1.00      0.50      0.67        10

    accuracy                           0.73     59904
   macro avg       0.59      0.79      0.64     59904
weighted avg       0.53      0.73      0.61     59904



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[[1 0 0 ... 0 0 1]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 1 0 ... 0 0 0]]
              precision    recall  f1-score   support

        -100       0.00      0.00      0.00     16394
           0       0.70      1.00      0.82     34257
           1       0.86      1.00      0.92      9127
           2       0.56      1.00      0.71        30
           3       0.41      0.61      0.49        18
           4       0.68      0.85      0.75        27
           5       0.70      0.84      0.76        19
           6       0.43      0.91      0.59        22
           7       0.83      1.00      0.91        10

    accuracy                           0.73     59904
   macro avg       0.57      0.80      0.66     59904
weighted avg       0.53      0.73      0.61     59904



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[[1 0 0 ... 0 0 1]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 1 0 ... 0 0 0]]
              precision    recall  f1-score   support

        -100       0.00      0.00      0.00     16394
           0       0.71      1.00      0.83     34257
           1       0.81      1.00      0.89      9127
           2       0.54      1.00      0.70        30
           3       0.23      0.89      0.37        18
           4       0.72      0.85      0.78        27
           5       0.55      0.89      0.68        19
           6       0.42      0.91      0.57        22
           7       0.83      1.00      0.91        10

    accuracy                           0.73     59904
   macro avg       0.53      0.84      0.64     59904
weighted avg       0.53      0.73      0.61     59904



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[[1 0 0 ... 0 0 1]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 1 0 ... 0 0 0]]
              precision    recall  f1-score   support

        -100       0.00      0.00      0.00     16394
           0       0.71      1.00      0.83     34257
           1       0.81      1.00      0.90      9127
           2       0.55      1.00      0.71        30
           3       0.31      0.89      0.46        18
           4       0.79      0.85      0.82        27
           5       0.55      0.89      0.68        19
           6       0.42      0.91      0.57        22
           7       0.77      1.00      0.87        10

    accuracy                           0.73     59904
   macro avg       0.54      0.84      0.65     59904
weighted avg       0.53      0.73      0.61     59904



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[[1 0 0 ... 0 0 1]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 1 0 ... 0 0 0]]
              precision    recall  f1-score   support

        -100       0.00      0.00      0.00     16394
           0       0.69      1.00      0.82     34257
           1       0.88      1.00      0.94      9127
           2       0.52      1.00      0.68        30
           3       0.30      0.89      0.45        18
           4       0.85      0.81      0.83        27
           5       0.46      0.89      0.61        19
           6       0.48      0.91      0.62        22
           7       0.83      1.00      0.91        10

    accuracy                           0.73     59904
   macro avg       0.56      0.83      0.65     59904
weighted avg       0.53      0.73      0.61     59904



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[[1 0 0 ... 0 0 1]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 1 0 ... 0 0 0]]
              precision    recall  f1-score   support

        -100       0.00      0.00      0.00     16394
           0       0.70      1.00      0.82     34257
           1       0.84      1.00      0.91      9127
           2       0.49      1.00      0.66        30
           3       0.31      0.89      0.46        18
           4       0.82      0.85      0.84        27
           5       0.44      0.89      0.59        19
           6       0.44      0.91      0.60        22
           7       0.83      1.00      0.91        10

    accuracy                           0.73     59904
   macro avg       0.54      0.84      0.64     59904
weighted avg       0.53      0.73      0.61     59904



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[[1 0 0 ... 0 0 1]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 1 0 ... 0 0 0]]
              precision    recall  f1-score   support

        -100       0.00      0.00      0.00     16394
           0       0.70      1.00      0.82     34257
           1       0.86      1.00      0.92      9127
           2       0.50      1.00      0.67        30
           3       0.30      0.89      0.44        18
           4       0.92      0.85      0.88        27
           5       0.52      0.89      0.65        19
           6       0.47      0.91      0.62        22
           7       0.83      1.00      0.91        10

    accuracy                           0.73     59904
   macro avg       0.56      0.84      0.66     59904
weighted avg       0.53      0.73      0.61     59904



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[[1 0 0 ... 0 0 1]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 1 0 ... 0 0 0]]
              precision    recall  f1-score   support

        -100       0.00      0.00      0.00     16394
           0       0.70      1.00      0.83     34257
           1       0.83      1.00      0.91      9127
           2       0.51      1.00      0.67        30
           3       0.30      0.89      0.45        18
           4       0.85      0.85      0.85        27
           5       0.49      0.89      0.63        19
           6       0.43      0.91      0.58        22
           7       0.83      1.00      0.91        10

    accuracy                           0.73     59904
   macro avg       0.55      0.84      0.65     59904
weighted avg       0.53      0.73      0.61     59904



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[[1 0 0 ... 0 0 1]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 1 0 ... 0 0 0]]
              precision    recall  f1-score   support

        -100       0.00      0.00      0.00     16394
           0       0.70      1.00      0.82     34257
           1       0.86      1.00      0.92      9127
           2       0.47      1.00      0.64        30
           3       0.27      0.89      0.42        18
           4       0.88      0.85      0.87        27
           5       0.44      0.89      0.59        19
           6       0.48      0.91      0.62        22
           7       0.83      1.00      0.91        10

    accuracy                           0.73     59904
   macro avg       0.55      0.84      0.64     59904
weighted avg       0.53      0.73      0.61     59904



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


TrainOutput(global_step=1170, training_loss=0.025446575580753833, metrics={'train_runtime': 182.5067, 'train_samples_per_second': 6.411, 'total_flos': 1529650796967360.0, 'epoch': 10.0, 'init_mem_cpu_alloc_delta': 3110973440, 'init_mem_gpu_alloc_delta': 431432192, 'init_mem_cpu_peaked_delta': 418070528, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 877699072, 'train_mem_gpu_alloc_delta': 1736098304, 'train_mem_cpu_peaked_delta': 177430528, 'train_mem_gpu_peaked_delta': 2836218368})