### Install & Import Packages

In [2]:
%%capture
!pip install "flair" -q

In [3]:
# Mount Google Drive at /content/drive; the data and model paths used later
# in this notebook live under that mount point.
from google.colab import drive
drive.mount('/content/drive')
# flair components used by this notebook.
# NOTE(review): WordEmbeddings, TransformerWordEmbeddings and StackedEmbeddings
# are not referenced in the cells visible here — confirm whether they are still needed.
import flair
from flair.data import Sentence
from flair.datasets import ColumnCorpus
from flair.embeddings import (
    WordEmbeddings, TransformerWordEmbeddings, StackedEmbeddings
)
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer
# Last expression of the cell: display the installed flair version for provenance.
flair.__version__

'0.12.2'

### Import Data

In [4]:
# Root of the project folder on the mounted shared drive. Both paths below are
# derived from it, so relocating the project is a one-line change.
PROJECT_ROOT = "/content/drive/Shareddrives/CIS522-Project"
# Folder containing the column-format train/test files read by ColumnCorpus.
DATA_PATH = f"{PROJECT_ROOT}/data"
# Folder under which taggers are loaded from and saved to.
MODEL_PATH = f"{PROJECT_ROOT}/models"

In [5]:
# Read the two-column files (column 0 = token text, column 1 = NER tag) into a flair corpus.
# NOTE(review): no dev_file is passed (the loader logs "Dev: None"), yet DEV scores
# are reported during training — presumably flair samples a dev split from the
# train file; confirm, since the train file here is the augmented one.
corpus = ColumnCorpus(
    DATA_PATH,
    {0: "text", 1: "ner"},
    train_file="flair_ner_train_augmented.txt",
    test_file="flair_ner_test.txt",
)

# Build the dictionary of "ner" labels (without an unknown-label entry) and list its items.
# NOTE(review): the tagger below is loaded from disk rather than constructed, so
# `tag_dictionary` looks unused in the visible cells — confirm.
tag_dictionary = corpus.make_label_dictionary(label_type="ner", add_unk=False)
print(tag_dictionary.get_items())

2023-04-15 06:02:07,568 Reading data from /content/drive/Shareddrives/CIS522-Project/data
2023-04-15 06:02:07,572 Train: /content/drive/Shareddrives/CIS522-Project/data/flair_ner_train_augmented.txt
2023-04-15 06:02:07,576 Dev: None
2023-04-15 06:02:07,577 Test: /content/drive/Shareddrives/CIS522-Project/data/flair_ner_test.txt
2023-04-15 06:02:34,037 Computing label dictionary. Progress:


2776it [00:00, 24020.36it/s]

2023-04-15 06:02:34,197 Dictionary created for label 'ner' with 9 values: Drug (seen 5597 times), ADE (seen 4556 times), Reason (seen 1281 times), Strength (seen 964 times), Route (seen 816 times), Frequency (seen 713 times), Form (seen 620 times), Dosage (seen 595 times), Duration (seen 127 times)
['Drug', 'ADE', 'Reason', 'Strength', 'Route', 'Frequency', 'Form', 'Dosage', 'Duration']





### Initialize Weight Dictionary
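This is the per-entity weight dictionary intended for the loss function. Each entity's weight is the count of the most frequent entity divided by that entity's own count (counts as reported by the label dictionary above), so the most frequent entity gets a weight of 1 and rarer entities get proportionally larger weights.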

In [None]:
# Number of times each entity label was seen, as printed by
# `corpus.make_label_dictionary` in the data-loading cell.
label_counts = {
    'Drug': 5597,
    'Strength': 964,
    'Form': 620,
    'Frequency': 713,
    'Route': 816,
    'Dosage': 595,
    'Reason': 1281,
    'Duration': 127,
    'ADE': 4556,
}

# Weight of an entity = (count of the most frequent entity) / (its own count).
# Deriving the numerator with max() keeps the weights consistent if the counts
# above are ever updated, instead of repeating the literal in every entry.
most_frequent_count = max(label_counts.values())
weight_dict = {
    label: most_frequent_count / count
    for label, count in label_counts.items()
}

In [7]:
# Load the previously trained tagger from the shared-drive model folder.
tf_tagger = SequenceTagger.load(f"{MODEL_PATH}/taggers/clinicalbert-crf/final-model.pt")
# Attach the per-entity weights to the loaded tagger.
# NOTE(review): this only sets an attribute on an already-constructed model; it
# may not change the loss that was built when the model was loaded. The model
# name suggests a CRF head, whose loss likely ignores class weights, and the
# tagger predicts prefixed tags (e.g. "B-Drug") while these keys are bare
# entity names. Confirm against flair's SequenceTagger (`loss_weights`)
# that the weights actually take effect during fine-tuning.
tf_tagger.weight_dict = weight_dict

Downloading (…)lve/main/config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

2023-04-15 06:02:59,045 SequenceTagger predicts: Dictionary with 39 tags: O, S-Drug, B-Drug, E-Drug, I-Drug, S-Strength, B-Strength, E-Strength, I-Strength, S-Form, B-Form, E-Form, I-Form, S-Frequency, B-Frequency, E-Frequency, I-Frequency, S-Route, B-Route, E-Route, I-Route, S-Dosage, B-Dosage, E-Dosage, I-Dosage, S-Reason, B-Reason, E-Reason, I-Reason, S-Duration, B-Duration, E-Duration, I-Duration, S-ADE, B-ADE, E-ADE, I-ADE, <START>, <STOP>


### Fine-Tune & Evaluate Model

In [8]:
# Pair the loaded tagger with the corpus in a flair trainer
trainer = ModelTrainer(model=tf_tagger, corpus=corpus)

# Fine-tune for a fixed number of epochs, writing results under `base_path`.
# The call is kept as the cell's final expression so its return value
# (the score/loss histories) is displayed as the cell output.
trainer.train(
    base_path=f"{MODEL_PATH}/taggers/clinicalbert-crf-augmented",
    # optimisation schedule
    learning_rate=0.005,
    mini_batch_size=16,
    max_epochs=10,
    # the dev split is used for evaluation only, not for training
    train_with_dev=False,
    # do not keep computed embeddings between batches
    embeddings_storage_mode='none',
)



2023-04-15 06:03:00,934 ----------------------------------------------------------------------------------------------------
2023-04-15 06:03:00,939 Model: "SequenceTagger(
  (embeddings): StackedEmbeddings(
    (list_embedding_0): TransformerWordEmbeddings(
      (model): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(28997, 768)
          (position_embeddings): Embedding(512, 768)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0-11): 12 x BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_fea

100%|██████████| 20/20 [00:13<00:00,  1.52it/s]


2023-04-15 06:06:59,966 Evaluating as a multi-label problem: False
2023-04-15 06:06:59,997 DEV : loss 0.1432962864637375 - f1-score (micro avg)  0.8595
2023-04-15 06:07:00,008 BAD EPOCHS (no improvement): 0
2023-04-15 06:07:00,012 saving best model
2023-04-15 06:07:01,518 ----------------------------------------------------------------------------------------------------
2023-04-15 06:07:25,119 epoch 2 - iter 17/174 - loss 0.24036553 - time (sec): 23.60 - samples/sec: 471.85 - lr: 0.005000
2023-04-15 06:07:46,485 epoch 2 - iter 34/174 - loss 0.23766284 - time (sec): 44.96 - samples/sec: 482.67 - lr: 0.005000
2023-04-15 06:08:10,454 epoch 2 - iter 51/174 - loss 0.23116425 - time (sec): 68.93 - samples/sec: 478.98 - lr: 0.005000
2023-04-15 06:08:31,643 epoch 2 - iter 68/174 - loss 0.23109893 - time (sec): 90.12 - samples/sec: 486.83 - lr: 0.005000
2023-04-15 06:08:53,295 epoch 2 - iter 85/174 - loss 0.23258418 - time (sec): 111.77 - samples/sec: 488.28 - lr: 0.005000
2023-04-15 06:09:17,

100%|██████████| 20/20 [00:11<00:00,  1.77it/s]


2023-04-15 06:10:59,996 Evaluating as a multi-label problem: False
2023-04-15 06:11:00,022 DEV : loss 0.11918845027685165 - f1-score (micro avg)  0.8898
2023-04-15 06:11:00,033 BAD EPOCHS (no improvement): 0
2023-04-15 06:11:00,038 saving best model
2023-04-15 06:11:01,614 ----------------------------------------------------------------------------------------------------
2023-04-15 06:11:22,930 epoch 3 - iter 17/174 - loss 0.21704233 - time (sec): 21.31 - samples/sec: 479.72 - lr: 0.005000
2023-04-15 06:11:46,091 epoch 3 - iter 34/174 - loss 0.21252368 - time (sec): 44.47 - samples/sec: 463.02 - lr: 0.005000
2023-04-15 06:12:06,636 epoch 3 - iter 51/174 - loss 0.21298039 - time (sec): 65.02 - samples/sec: 473.65 - lr: 0.005000
2023-04-15 06:12:27,990 epoch 3 - iter 68/174 - loss 0.20666109 - time (sec): 86.37 - samples/sec: 479.95 - lr: 0.005000
2023-04-15 06:12:51,182 epoch 3 - iter 85/174 - loss 0.20315278 - time (sec): 109.56 - samples/sec: 477.45 - lr: 0.005000
2023-04-15 06:13:14

100%|██████████| 20/20 [00:10<00:00,  1.83it/s]

2023-04-15 06:15:02,409 Evaluating as a multi-label problem: False
2023-04-15 06:15:02,445 DEV : loss 0.10443847626447678 - f1-score (micro avg)  0.8976
2023-04-15 06:15:02,464 BAD EPOCHS (no improvement): 0
2023-04-15 06:15:02,470 saving best model





2023-04-15 06:15:04,292 ----------------------------------------------------------------------------------------------------
2023-04-15 06:15:26,683 epoch 4 - iter 17/174 - loss 0.17012507 - time (sec): 22.38 - samples/sec: 470.38 - lr: 0.005000
2023-04-15 06:15:50,540 epoch 4 - iter 34/174 - loss 0.17791269 - time (sec): 46.23 - samples/sec: 462.75 - lr: 0.005000
2023-04-15 06:16:14,177 epoch 4 - iter 51/174 - loss 0.17282138 - time (sec): 69.87 - samples/sec: 468.66 - lr: 0.005000
2023-04-15 06:16:35,486 epoch 4 - iter 68/174 - loss 0.17414256 - time (sec): 91.18 - samples/sec: 471.10 - lr: 0.005000
2023-04-15 06:16:57,038 epoch 4 - iter 85/174 - loss 0.17677838 - time (sec): 112.73 - samples/sec: 474.55 - lr: 0.005000
2023-04-15 06:17:18,496 epoch 4 - iter 102/174 - loss 0.17544181 - time (sec): 134.19 - samples/sec: 478.37 - lr: 0.005000
2023-04-15 06:17:40,828 epoch 4 - iter 119/174 - loss 0.17562583 - time (sec): 156.52 - samples/sec: 477.74 - lr: 0.005000
2023-04-15 06:18:04,252

100%|██████████| 20/20 [00:12<00:00,  1.65it/s]

2023-04-15 06:19:05,939 Evaluating as a multi-label problem: False
2023-04-15 06:19:05,974 DEV : loss 0.08795062452554703 - f1-score (micro avg)  0.915
2023-04-15 06:19:05,991 BAD EPOCHS (no improvement): 0
2023-04-15 06:19:05,998 saving best model





2023-04-15 06:19:07,684 ----------------------------------------------------------------------------------------------------
2023-04-15 06:19:29,587 epoch 5 - iter 17/174 - loss 0.15494907 - time (sec): 21.89 - samples/sec: 487.90 - lr: 0.005000
2023-04-15 06:19:51,644 epoch 5 - iter 34/174 - loss 0.15975504 - time (sec): 43.95 - samples/sec: 490.73 - lr: 0.005000
2023-04-15 06:20:14,822 epoch 5 - iter 51/174 - loss 0.15744556 - time (sec): 67.13 - samples/sec: 480.94 - lr: 0.005000
2023-04-15 06:20:37,099 epoch 5 - iter 68/174 - loss 0.15730954 - time (sec): 89.41 - samples/sec: 479.59 - lr: 0.005000
2023-04-15 06:21:00,725 epoch 5 - iter 85/174 - loss 0.15775008 - time (sec): 113.03 - samples/sec: 476.48 - lr: 0.005000
2023-04-15 06:21:23,667 epoch 5 - iter 102/174 - loss 0.15969113 - time (sec): 135.97 - samples/sec: 476.16 - lr: 0.005000
2023-04-15 06:21:47,298 epoch 5 - iter 119/174 - loss 0.15863450 - time (sec): 159.60 - samples/sec: 474.67 - lr: 0.005000
2023-04-15 06:22:10,030

100%|██████████| 20/20 [00:11<00:00,  1.76it/s]

2023-04-15 06:23:09,233 Evaluating as a multi-label problem: False
2023-04-15 06:23:09,257 DEV : loss 0.09466564655303955 - f1-score (micro avg)  0.9111
2023-04-15 06:23:09,267 BAD EPOCHS (no improvement): 1
2023-04-15 06:23:09,271 ----------------------------------------------------------------------------------------------------





2023-04-15 06:23:35,224 epoch 6 - iter 17/174 - loss 0.13131950 - time (sec): 25.95 - samples/sec: 431.03 - lr: 0.005000
2023-04-15 06:24:00,317 epoch 6 - iter 34/174 - loss 0.14246694 - time (sec): 51.04 - samples/sec: 437.50 - lr: 0.005000
2023-04-15 06:24:21,940 epoch 6 - iter 51/174 - loss 0.14429196 - time (sec): 72.67 - samples/sec: 449.88 - lr: 0.005000
2023-04-15 06:24:43,263 epoch 6 - iter 68/174 - loss 0.14661832 - time (sec): 93.99 - samples/sec: 460.97 - lr: 0.005000
2023-04-15 06:25:04,929 epoch 6 - iter 85/174 - loss 0.14519449 - time (sec): 115.65 - samples/sec: 467.39 - lr: 0.005000
2023-04-15 06:25:27,311 epoch 6 - iter 102/174 - loss 0.14425636 - time (sec): 138.04 - samples/sec: 466.67 - lr: 0.005000
2023-04-15 06:25:47,867 epoch 6 - iter 119/174 - loss 0.14687581 - time (sec): 158.59 - samples/sec: 473.07 - lr: 0.005000
2023-04-15 06:26:09,722 epoch 6 - iter 136/174 - loss 0.14579106 - time (sec): 180.45 - samples/sec: 477.27 - lr: 0.005000
2023-04-15 06:26:31,361 e

100%|██████████| 20/20 [00:09<00:00,  2.00it/s]


2023-04-15 06:27:07,597 Evaluating as a multi-label problem: False
2023-04-15 06:27:07,634 DEV : loss 0.07021188735961914 - f1-score (micro avg)  0.9299
2023-04-15 06:27:07,653 BAD EPOCHS (no improvement): 0
2023-04-15 06:27:07,660 saving best model
2023-04-15 06:27:09,521 ----------------------------------------------------------------------------------------------------
2023-04-15 06:27:31,703 epoch 7 - iter 17/174 - loss 0.15356296 - time (sec): 22.18 - samples/sec: 480.05 - lr: 0.005000
2023-04-15 06:27:55,127 epoch 7 - iter 34/174 - loss 0.15251524 - time (sec): 45.60 - samples/sec: 469.02 - lr: 0.005000
2023-04-15 06:28:16,795 epoch 7 - iter 51/174 - loss 0.14460939 - time (sec): 67.27 - samples/sec: 476.50 - lr: 0.005000
2023-04-15 06:28:38,403 epoch 7 - iter 68/174 - loss 0.14744792 - time (sec): 88.88 - samples/sec: 483.83 - lr: 0.005000
2023-04-15 06:29:02,294 epoch 7 - iter 85/174 - loss 0.14503115 - time (sec): 112.77 - samples/sec: 480.91 - lr: 0.005000
2023-04-15 06:29:22

100%|██████████| 20/20 [00:11<00:00,  1.76it/s]


2023-04-15 06:31:11,276 Evaluating as a multi-label problem: False
2023-04-15 06:31:11,301 DEV : loss 0.057655636221170425 - f1-score (micro avg)  0.9452
2023-04-15 06:31:11,314 BAD EPOCHS (no improvement): 0
2023-04-15 06:31:11,318 saving best model
2023-04-15 06:31:12,831 ----------------------------------------------------------------------------------------------------
2023-04-15 06:31:37,710 epoch 8 - iter 17/174 - loss 0.13058652 - time (sec): 24.88 - samples/sec: 446.07 - lr: 0.005000
2023-04-15 06:31:59,636 epoch 8 - iter 34/174 - loss 0.13241014 - time (sec): 46.80 - samples/sec: 464.33 - lr: 0.005000
2023-04-15 06:32:22,395 epoch 8 - iter 51/174 - loss 0.12935195 - time (sec): 69.56 - samples/sec: 469.97 - lr: 0.005000
2023-04-15 06:32:44,051 epoch 8 - iter 68/174 - loss 0.12683738 - time (sec): 91.22 - samples/sec: 476.47 - lr: 0.005000
2023-04-15 06:33:06,635 epoch 8 - iter 85/174 - loss 0.12957865 - time (sec): 113.80 - samples/sec: 474.12 - lr: 0.005000
2023-04-15 06:33:3

100%|██████████| 20/20 [00:11<00:00,  1.76it/s]


2023-04-15 06:35:15,139 Evaluating as a multi-label problem: False
2023-04-15 06:35:15,162 DEV : loss 0.05484052002429962 - f1-score (micro avg)  0.9435
2023-04-15 06:35:15,172 BAD EPOCHS (no improvement): 1
2023-04-15 06:35:15,177 ----------------------------------------------------------------------------------------------------
2023-04-15 06:35:40,669 epoch 9 - iter 17/174 - loss 0.11509604 - time (sec): 25.49 - samples/sec: 437.41 - lr: 0.005000
2023-04-15 06:36:02,768 epoch 9 - iter 34/174 - loss 0.11566064 - time (sec): 47.59 - samples/sec: 468.53 - lr: 0.005000
2023-04-15 06:36:24,100 epoch 9 - iter 51/174 - loss 0.12188893 - time (sec): 68.92 - samples/sec: 476.69 - lr: 0.005000
2023-04-15 06:36:46,415 epoch 9 - iter 68/174 - loss 0.12281571 - time (sec): 91.24 - samples/sec: 476.69 - lr: 0.005000
2023-04-15 06:37:06,933 epoch 9 - iter 85/174 - loss 0.12178176 - time (sec): 111.75 - samples/sec: 485.82 - lr: 0.005000
2023-04-15 06:37:29,578 epoch 9 - iter 102/174 - loss 0.12298

100%|██████████| 20/20 [00:12<00:00,  1.66it/s]

2023-04-15 06:39:15,401 Evaluating as a multi-label problem: False
2023-04-15 06:39:15,439 DEV : loss 0.04831854626536369 - f1-score (micro avg)  0.9534
2023-04-15 06:39:15,458 BAD EPOCHS (no improvement): 0
2023-04-15 06:39:15,465 saving best model





2023-04-15 06:39:17,116 ----------------------------------------------------------------------------------------------------
2023-04-15 06:39:39,246 epoch 10 - iter 17/174 - loss 0.11980080 - time (sec): 22.12 - samples/sec: 470.49 - lr: 0.005000
2023-04-15 06:40:02,914 epoch 10 - iter 34/174 - loss 0.11768876 - time (sec): 45.79 - samples/sec: 466.57 - lr: 0.005000
2023-04-15 06:40:24,753 epoch 10 - iter 51/174 - loss 0.11892411 - time (sec): 67.63 - samples/sec: 474.38 - lr: 0.005000
2023-04-15 06:40:47,022 epoch 10 - iter 68/174 - loss 0.11972931 - time (sec): 89.90 - samples/sec: 478.71 - lr: 0.005000
2023-04-15 06:41:10,527 epoch 10 - iter 85/174 - loss 0.11826973 - time (sec): 113.40 - samples/sec: 477.19 - lr: 0.005000
2023-04-15 06:41:32,287 epoch 10 - iter 102/174 - loss 0.11568852 - time (sec): 135.16 - samples/sec: 478.04 - lr: 0.005000
2023-04-15 06:41:55,783 epoch 10 - iter 119/174 - loss 0.11444865 - time (sec): 158.66 - samples/sec: 476.96 - lr: 0.005000
2023-04-15 06:42

100%|██████████| 20/20 [00:10<00:00,  1.95it/s]


2023-04-15 06:43:15,013 Evaluating as a multi-label problem: False
2023-04-15 06:43:15,039 DEV : loss 0.049334339797496796 - f1-score (micro avg)  0.9528
2023-04-15 06:43:15,050 BAD EPOCHS (no improvement): 1
2023-04-15 06:43:20,496 ----------------------------------------------------------------------------------------------------
2023-04-15 06:43:24,078 SequenceTagger predicts: Dictionary with 39 tags: O, S-Drug, B-Drug, E-Drug, I-Drug, S-Strength, B-Strength, E-Strength, I-Strength, S-Form, B-Form, E-Form, I-Form, S-Frequency, B-Frequency, E-Frequency, I-Frequency, S-Route, B-Route, E-Route, I-Route, S-Dosage, B-Dosage, E-Dosage, I-Dosage, S-Reason, B-Reason, E-Reason, I-Reason, S-Duration, B-Duration, E-Duration, I-Duration, S-ADE, B-ADE, E-ADE, I-ADE, <START>, <STOP>


100%|██████████| 1588/1588 [13:16<00:00,  1.99it/s]


2023-04-15 06:56:42,619 Evaluating as a multi-label problem: False
2023-04-15 06:56:44,785 0.8869	0.9188	0.9026	0.8296
2023-04-15 06:56:44,786 
Results:
- F-score (micro) 0.9026
- F-score (macro) 0.8172
- Accuracy 0.8296

By class:
              precision    recall  f1-score   support

        Drug     0.8731    0.9410    0.9058     61167
    Strength     0.9357    0.9562    0.9458     42957
        Form     0.9178    0.9228    0.9203     41417
   Frequency     0.8551    0.8580    0.8566     36495
       Route     0.9509    0.9424    0.9466     30583
      Dosage     0.9244    0.9455    0.9348     23506
      Reason     0.7222    0.7490    0.7353      9533
         ADE     0.2378    0.6451    0.3475      1299
    Duration     0.7466    0.7790    0.7625      1982

   micro avg     0.8869    0.9188    0.9026    248939
   macro avg     0.7960    0.8599    0.8172    248939
weighted avg     0.8930    0.9188    0.9051    248939

2023-04-15 06:56:44,791 ---------------------------------------

{'test_score': 0.9025737143759063,
 'dev_score_history': [0.8595224875069406,
  0.8897593732512591,
  0.8976290097629009,
  0.9150253235790659,
  0.9111479028697572,
  0.9299435028248587,
  0.945174186179326,
  0.9435140505251206,
  0.953388618816128,
  0.9527872582480091],
 'train_loss_history': [0.2693294886703613,
  0.22516719765072943,
  0.19442447590181217,
  0.17497128857277183,
  0.1607360814800832,
  0.14557645680775588,
  0.13896159373575337,
  0.1322365663109756,
  0.12198050303097219,
  0.11407735211989971],
 'dev_loss_history': [0.1432962864637375,
  0.11918845027685165,
  0.10443847626447678,
  0.08795062452554703,
  0.09466564655303955,
  0.07021188735961914,
  0.057655636221170425,
  0.05484052002429962,
  0.04831854626536369,
  0.049334339797496796]}

In [None]:
# Hand-written example sentence used as a quick sanity check of the tagger
sentence = Sentence("Patients on 40 mg of Topelfate and Topoxy twice a day for stomachache generally suffer from headache")

# Tag the sentence twice and print each result:
#   1st pass (True)  -> one tag per token
#   2nd pass (False) -> the default prediction, printed as whole spans
for token_level in (True, False):
    tf_tagger.predict(sentence, force_token_predictions=token_level)
    print(sentence.to_tagged_string())

Sentence[17]: "Patients on 40 mg of Topelfate and Topoxy twice a day for stomachache generally suffer from headache" → ["40"/B-Strength, "mg"/E-Strength, "Topelfate"/S-Drug, "Topoxy"/S-Drug, "twice"/B-Frequency, "a"/I-Frequency, "day"/E-Frequency, "stomachache"/S-Reason, "headache"/S-ADE]
Sentence[17]: "Patients on 40 mg of Topelfate and Topoxy twice a day for stomachache generally suffer from headache" → ["40 mg"/Strength, "Topelfate"/Drug, "Topoxy"/Drug, "twice a day"/Frequency, "stomachache"/Reason, "headache"/ADE]
