### Install & Import Packages

In [3]:
%%capture
!pip install "flair" -q

In [4]:
# --- flair: data containers, corpus reader, embeddings, tagger model and trainer ---
import flair
from flair.data import Sentence
from flair.datasets import ColumnCorpus
from flair.embeddings import StackedEmbeddings, TransformerWordEmbeddings, WordEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer

# --- Colab: mount Google Drive so the shared project folders are reachable ---
from google.colab import drive

drive.mount('/content/drive')

# Show the installed flair version as the cell output.
flair.__version__

'0.12.2'

### Import Data

In [5]:
# Shared-drive root of the project; data and trained models live underneath it.
PROJECT_ROOT = "/content/drive/Shareddrives/CIS522-Project"
DATA_PATH = f"{PROJECT_ROOT}/data"
MODEL_PATH = f"{PROJECT_ROOT}/models"

# Seed the random number generators before the corpus is loaded and the model is trained.
# Without this, batch shuffling, dropout and the dev split flair builds when no dev file
# is supplied differ from run to run, so the reported scores are not reproducible.
SEED = 42
flair.set_seed(SEED)

In [6]:
# Read the column-formatted NER files into a flair corpus:
# column 0 holds the token text, column 1 its NER tag.
# No dev file is passed; presumably flair samples a dev split from the (augmented)
# training data itself, which is what the trainer later reports as DEV scores.
corpus = ColumnCorpus(
    DATA_PATH,
    column_format={0: "text", 1: "ner"},
    train_file="flair_ner_train_augmented.txt",
    test_file="flair_ner_test.txt",
)

# Entity-type dictionary of the corpus (no <unk> entry). It is only inspected here:
# the tagger below is loaded from disk rather than built from this dictionary.
tag_dictionary = corpus.make_label_dictionary(label_type="ner", add_unk=False)
print(tag_dictionary.get_items())

2023-04-18 22:38:59,855 Reading data from /content/drive/Shareddrives/CIS522-Project/data
2023-04-18 22:38:59,862 Train: /content/drive/Shareddrives/CIS522-Project/data/flair_ner_train_augmented.txt
2023-04-18 22:38:59,863 Dev: None
2023-04-18 22:38:59,866 Test: /content/drive/Shareddrives/CIS522-Project/data/flair_ner_test.txt
2023-04-18 22:39:27,841 Computing label dictionary. Progress:


2776it [00:00, 35853.94it/s]

2023-04-18 22:39:27,957 Dictionary created for label 'ner' with 9 values: Drug (seen 5609 times), ADE (seen 4577 times), Reason (seen 1269 times), Strength (seen 972 times), Route (seen 824 times), Frequency (seen 723 times), Form (seen 620 times), Dosage (seen 598 times), Duration (seen 121 times)
['Drug', 'ADE', 'Reason', 'Strength', 'Route', 'Frequency', 'Form', 'Dosage', 'Duration']





### Initialize Weight Dictionary
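This is the weight dictionary intended for the loss function: the ADE and Reason entities get a weight of 5, and all remaining entities get a weight of 1. Note how flair consumes these weights:
- flair's `SequenceTagger` builds its loss when the model is constructed.
- It looks weights up per BIOES tag (e.g. `B-ADE`), not per entity type.
- A CRF tagger's Viterbi loss takes no class weights at all.

So these weights only take effect when they are expanded to tags and installed in a non-CRF (cross-entropy) loss.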

In [7]:
# Per-entity loss weight: ADE and Reason are up-weighted to 5, every other entity stays at 1.
weight_dict = {
    entity: 5 if entity in ("Reason", "ADE") else 1
    for entity in (
        "Drug", "Strength", "Form", "Frequency", "Route",
        "Dosage", "Reason", "Duration", "ADE",
    )
}

In [8]:
tf_tagger = SequenceTagger.load(f"{MODEL_PATH}/taggers/clinicalbert-crf/final-model.pt")

# flair (0.12) looks loss weights up by *tag* ("B-ADE", "S-Reason", ...), not by entity type,
# so give every BIOES tag the weight of its entity. Tags without an entity part
# ("O", "<START>", "<STOP>") fall back to weight 1.
tag_weights = {
    tag: weight_dict.get(tag.split("-", 1)[-1], 1)
    for tag in tf_tagger.label_dictionary.get_items()
}
tf_tagger.weight_dict = tag_weights

# SequenceTagger creates its loss function in __init__, so assigning `weight_dict` to an
# already-loaded model does not by itself change the training loss.
# NOTE(review): based on flair 0.12 internals -- re-check if flair is upgraded.
if tf_tagger.use_crf:
    # A CRF tagger is trained with ViterbiLoss, which takes no class weights: the weights
    # cannot take effect for this model, and the fine-tuning below is unweighted.
    print("WARNING: CRF tagger - loss weights are ignored (ViterbiLoss has no class weights).")
else:
    # Rebuild the per-tag weight tensor and install it in the existing CrossEntropyLoss.
    tf_tagger.loss_weights = tf_tagger._init_loss_weights(tag_weights)
    tf_tagger.loss_function.weight = tf_tagger.loss_weights

Downloading (…)lve/main/config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

2023-04-18 22:39:45,747 SequenceTagger predicts: Dictionary with 39 tags: O, S-Drug, B-Drug, E-Drug, I-Drug, S-Strength, B-Strength, E-Strength, I-Strength, S-Form, B-Form, E-Form, I-Form, S-Frequency, B-Frequency, E-Frequency, I-Frequency, S-Route, B-Route, E-Route, I-Route, S-Dosage, B-Dosage, E-Dosage, I-Dosage, S-Reason, B-Reason, E-Reason, I-Reason, S-Duration, B-Duration, E-Duration, I-Duration, S-ADE, B-ADE, E-ADE, I-ADE, <START>, <STOP>


### Fine-Tune & Evaluate Model

In [9]:
# Fine-tune the loaded tagger on the corpus. Checkpoints and logs are written to base_path;
# the returned dict (test score plus per-epoch dev/train histories) is the cell output.
# NOTE(review): with a CRF tagger flair's Viterbi loss ignores class weights, so despite the
# "-weights" folder name this run is presumably unweighted -- confirm before reporting.
trainer = ModelTrainer(tf_tagger, corpus)

trainer.train(
    base_path=f"{MODEL_PATH}/taggers/clinicalbert-crf-augmented-weights",
    learning_rate=0.005,
    mini_batch_size=16,
    max_epochs=10,
    train_with_dev=False,           # keep the dev split for model selection
    embeddings_storage_mode='none', # do not cache embeddings between epochs
)



2023-04-18 22:39:47,242 ----------------------------------------------------------------------------------------------------
2023-04-18 22:39:47,246 Model: "SequenceTagger(
  (embeddings): StackedEmbeddings(
    (list_embedding_0): TransformerWordEmbeddings(
      (model): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(28997, 768)
          (position_embeddings): Embedding(512, 768)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0-11): 12 x BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_fea

100%|██████████| 20/20 [00:10<00:00,  2.00it/s]


2023-04-18 22:43:59,304 Evaluating as a multi-label problem: False
2023-04-18 22:43:59,353 DEV : loss 0.1806982308626175 - f1-score (micro avg)  0.8479
2023-04-18 22:43:59,371 BAD EPOCHS (no improvement): 0
2023-04-18 22:43:59,379 saving best model
2023-04-18 22:44:03,933 ----------------------------------------------------------------------------------------------------
2023-04-18 22:44:27,991 epoch 2 - iter 17/174 - loss 0.23779200 - time (sec): 24.06 - samples/sec: 441.09 - lr: 0.005000
2023-04-18 22:44:51,723 epoch 2 - iter 34/174 - loss 0.23879142 - time (sec): 47.79 - samples/sec: 442.48 - lr: 0.005000
2023-04-18 22:45:13,410 epoch 2 - iter 51/174 - loss 0.23540799 - time (sec): 69.47 - samples/sec: 453.23 - lr: 0.005000
2023-04-18 22:45:43,675 epoch 2 - iter 68/174 - loss 0.22577977 - time (sec): 99.74 - samples/sec: 438.70 - lr: 0.005000
2023-04-18 22:46:07,189 epoch 2 - iter 85/174 - loss 0.22348648 - time (sec): 123.25 - samples/sec: 441.52 - lr: 0.005000
2023-04-18 22:46:33,

100%|██████████| 20/20 [00:10<00:00,  1.91it/s]

2023-04-18 22:48:22,443 Evaluating as a multi-label problem: False





2023-04-18 22:48:22,492 DEV : loss 0.14165449142456055 - f1-score (micro avg)  0.8684
2023-04-18 22:48:22,508 BAD EPOCHS (no improvement): 0
2023-04-18 22:48:22,517 saving best model
2023-04-18 22:48:24,485 ----------------------------------------------------------------------------------------------------
2023-04-18 22:48:48,231 epoch 3 - iter 17/174 - loss 0.18690398 - time (sec): 23.74 - samples/sec: 451.92 - lr: 0.005000
2023-04-18 22:49:13,855 epoch 3 - iter 34/174 - loss 0.19966754 - time (sec): 49.36 - samples/sec: 433.64 - lr: 0.005000
2023-04-18 22:49:37,719 epoch 3 - iter 51/174 - loss 0.20147585 - time (sec): 73.23 - samples/sec: 443.58 - lr: 0.005000
2023-04-18 22:50:02,651 epoch 3 - iter 68/174 - loss 0.20135794 - time (sec): 98.16 - samples/sec: 440.54 - lr: 0.005000
2023-04-18 22:50:28,381 epoch 3 - iter 85/174 - loss 0.19779455 - time (sec): 123.89 - samples/sec: 439.97 - lr: 0.005000
2023-04-18 22:50:52,147 epoch 3 - iter 102/174 - loss 0.19749803 - time (sec): 147.65 

100%|██████████| 20/20 [00:10<00:00,  1.86it/s]

2023-04-18 22:52:45,462 Evaluating as a multi-label problem: False
2023-04-18 22:52:45,488 DEV : loss 0.12228060513734818 - f1-score (micro avg)  0.8923
2023-04-18 22:52:45,499 BAD EPOCHS (no improvement): 0
2023-04-18 22:52:45,503 saving best model





2023-04-18 22:52:47,137 ----------------------------------------------------------------------------------------------------
2023-04-18 22:53:11,629 epoch 4 - iter 17/174 - loss 0.18977465 - time (sec): 24.49 - samples/sec: 444.19 - lr: 0.005000
2023-04-18 22:53:34,924 epoch 4 - iter 34/174 - loss 0.17957257 - time (sec): 47.78 - samples/sec: 452.55 - lr: 0.005000
2023-04-18 22:53:57,845 epoch 4 - iter 51/174 - loss 0.17786317 - time (sec): 70.71 - samples/sec: 454.15 - lr: 0.005000
2023-04-18 22:54:20,974 epoch 4 - iter 68/174 - loss 0.17952308 - time (sec): 93.83 - samples/sec: 454.74 - lr: 0.005000
2023-04-18 22:54:43,133 epoch 4 - iter 85/174 - loss 0.17718851 - time (sec): 115.99 - samples/sec: 459.78 - lr: 0.005000
2023-04-18 22:55:08,724 epoch 4 - iter 102/174 - loss 0.17821750 - time (sec): 141.58 - samples/sec: 455.66 - lr: 0.005000
2023-04-18 22:55:34,260 epoch 4 - iter 119/174 - loss 0.17611427 - time (sec): 167.12 - samples/sec: 450.98 - lr: 0.005000
2023-04-18 22:55:57,415

100%|██████████| 20/20 [00:11<00:00,  1.74it/s]

2023-04-18 22:57:06,053 Evaluating as a multi-label problem: False
2023-04-18 22:57:06,079 DEV : loss 0.11125482618808746 - f1-score (micro avg)  0.9022
2023-04-18 22:57:06,090 BAD EPOCHS (no improvement): 0
2023-04-18 22:57:06,096 saving best model





2023-04-18 22:57:08,016 ----------------------------------------------------------------------------------------------------
2023-04-18 22:57:34,663 epoch 5 - iter 17/174 - loss 0.16450350 - time (sec): 26.63 - samples/sec: 434.68 - lr: 0.005000
2023-04-18 22:57:58,364 epoch 5 - iter 34/174 - loss 0.16074829 - time (sec): 50.33 - samples/sec: 445.11 - lr: 0.005000
2023-04-18 22:58:22,665 epoch 5 - iter 51/174 - loss 0.16392104 - time (sec): 74.63 - samples/sec: 442.14 - lr: 0.005000
2023-04-18 22:58:47,836 epoch 5 - iter 68/174 - loss 0.16046323 - time (sec): 99.81 - samples/sec: 440.28 - lr: 0.005000
2023-04-18 22:59:09,319 epoch 5 - iter 85/174 - loss 0.15960699 - time (sec): 121.29 - samples/sec: 451.79 - lr: 0.005000
2023-04-18 22:59:34,324 epoch 5 - iter 102/174 - loss 0.16063591 - time (sec): 146.29 - samples/sec: 446.87 - lr: 0.005000
2023-04-18 22:59:56,293 epoch 5 - iter 119/174 - loss 0.16198220 - time (sec): 168.26 - samples/sec: 449.34 - lr: 0.005000
2023-04-18 23:00:23,738

100%|██████████| 20/20 [00:11<00:00,  1.76it/s]

2023-04-18 23:01:28,030 Evaluating as a multi-label problem: False
2023-04-18 23:01:28,058 DEV : loss 0.10138115286827087 - f1-score (micro avg)  0.9159
2023-04-18 23:01:28,069 BAD EPOCHS (no improvement): 0
2023-04-18 23:01:28,075 saving best model





2023-04-18 23:01:29,615 ----------------------------------------------------------------------------------------------------
2023-04-18 23:01:54,524 epoch 6 - iter 17/174 - loss 0.16089711 - time (sec): 24.90 - samples/sec: 441.72 - lr: 0.005000
2023-04-18 23:02:17,912 epoch 6 - iter 34/174 - loss 0.14959095 - time (sec): 48.29 - samples/sec: 449.69 - lr: 0.005000
2023-04-18 23:02:41,944 epoch 6 - iter 51/174 - loss 0.15020033 - time (sec): 72.32 - samples/sec: 446.28 - lr: 0.005000
2023-04-18 23:03:05,439 epoch 6 - iter 68/174 - loss 0.14860578 - time (sec): 95.82 - samples/sec: 445.98 - lr: 0.005000
2023-04-18 23:03:28,950 epoch 6 - iter 85/174 - loss 0.14916982 - time (sec): 119.33 - samples/sec: 445.83 - lr: 0.005000
2023-04-18 23:03:53,233 epoch 6 - iter 102/174 - loss 0.14799946 - time (sec): 143.61 - samples/sec: 442.86 - lr: 0.005000
2023-04-18 23:04:18,623 epoch 6 - iter 119/174 - loss 0.14830934 - time (sec): 169.00 - samples/sec: 443.15 - lr: 0.005000
2023-04-18 23:04:46,609

100%|██████████| 20/20 [00:11<00:00,  1.72it/s]

2023-04-18 23:05:51,335 Evaluating as a multi-label problem: False
2023-04-18 23:05:51,362 DEV : loss 0.09453363716602325 - f1-score (micro avg)  0.9215
2023-04-18 23:05:51,380 BAD EPOCHS (no improvement): 0
2023-04-18 23:05:51,390 saving best model





2023-04-18 23:05:52,920 ----------------------------------------------------------------------------------------------------
2023-04-18 23:06:19,298 epoch 7 - iter 17/174 - loss 0.13701348 - time (sec): 26.37 - samples/sec: 432.23 - lr: 0.005000
2023-04-18 23:06:41,742 epoch 7 - iter 34/174 - loss 0.13636363 - time (sec): 48.81 - samples/sec: 454.26 - lr: 0.005000
2023-04-18 23:07:06,196 epoch 7 - iter 51/174 - loss 0.13559633 - time (sec): 73.27 - samples/sec: 453.14 - lr: 0.005000
2023-04-18 23:07:31,332 epoch 7 - iter 68/174 - loss 0.13180796 - time (sec): 98.40 - samples/sec: 451.52 - lr: 0.005000
2023-04-18 23:07:55,200 epoch 7 - iter 85/174 - loss 0.13449468 - time (sec): 122.27 - samples/sec: 446.81 - lr: 0.005000
2023-04-18 23:08:19,536 epoch 7 - iter 102/174 - loss 0.13499529 - time (sec): 146.61 - samples/sec: 445.77 - lr: 0.005000
2023-04-18 23:08:42,724 epoch 7 - iter 119/174 - loss 0.13566471 - time (sec): 169.80 - samples/sec: 447.20 - lr: 0.005000
2023-04-18 23:09:05,960

100%|██████████| 20/20 [00:11<00:00,  1.73it/s]

2023-04-18 23:10:11,573 Evaluating as a multi-label problem: False
2023-04-18 23:10:11,599 DEV : loss 0.09184014052152634 - f1-score (micro avg)  0.9266
2023-04-18 23:10:11,613 BAD EPOCHS (no improvement): 0
2023-04-18 23:10:11,618 saving best model





2023-04-18 23:10:13,145 ----------------------------------------------------------------------------------------------------
2023-04-18 23:10:40,609 epoch 8 - iter 17/174 - loss 0.12808948 - time (sec): 27.46 - samples/sec: 408.94 - lr: 0.005000
2023-04-18 23:11:05,968 epoch 8 - iter 34/174 - loss 0.13036181 - time (sec): 52.82 - samples/sec: 425.27 - lr: 0.005000
2023-04-18 23:11:29,573 epoch 8 - iter 51/174 - loss 0.12581306 - time (sec): 76.43 - samples/sec: 437.05 - lr: 0.005000
2023-04-18 23:11:51,752 epoch 8 - iter 68/174 - loss 0.12966803 - time (sec): 98.60 - samples/sec: 445.59 - lr: 0.005000
2023-04-18 23:12:16,378 epoch 8 - iter 85/174 - loss 0.12688955 - time (sec): 123.23 - samples/sec: 444.27 - lr: 0.005000
2023-04-18 23:12:43,163 epoch 8 - iter 102/174 - loss 0.12612285 - time (sec): 150.02 - samples/sec: 439.43 - lr: 0.005000
2023-04-18 23:13:06,063 epoch 8 - iter 119/174 - loss 0.12662446 - time (sec): 172.92 - samples/sec: 439.26 - lr: 0.005000
2023-04-18 23:13:30,109

100%|██████████| 20/20 [00:11<00:00,  1.76it/s]

2023-04-18 23:14:34,815 Evaluating as a multi-label problem: False
2023-04-18 23:14:34,840 DEV : loss 0.08141843974590302 - f1-score (micro avg)  0.9353
2023-04-18 23:14:34,852 BAD EPOCHS (no improvement): 0
2023-04-18 23:14:34,856 saving best model





2023-04-18 23:14:36,353 ----------------------------------------------------------------------------------------------------
2023-04-18 23:14:59,566 epoch 9 - iter 17/174 - loss 0.13123296 - time (sec): 23.21 - samples/sec: 438.76 - lr: 0.005000
2023-04-18 23:15:24,053 epoch 9 - iter 34/174 - loss 0.12039967 - time (sec): 47.69 - samples/sec: 444.97 - lr: 0.005000
2023-04-18 23:15:48,612 epoch 9 - iter 51/174 - loss 0.12031202 - time (sec): 72.25 - samples/sec: 447.16 - lr: 0.005000
2023-04-18 23:16:11,658 epoch 9 - iter 68/174 - loss 0.12328567 - time (sec): 95.30 - samples/sec: 449.78 - lr: 0.005000
2023-04-18 23:16:36,359 epoch 9 - iter 85/174 - loss 0.12222636 - time (sec): 120.00 - samples/sec: 447.02 - lr: 0.005000
2023-04-18 23:17:03,851 epoch 9 - iter 102/174 - loss 0.11853210 - time (sec): 147.49 - samples/sec: 441.49 - lr: 0.005000
2023-04-18 23:17:28,627 epoch 9 - iter 119/174 - loss 0.11994465 - time (sec): 172.27 - samples/sec: 442.81 - lr: 0.005000
2023-04-18 23:17:53,220

100%|██████████| 20/20 [00:11<00:00,  1.75it/s]

2023-04-18 23:18:57,359 Evaluating as a multi-label problem: False
2023-04-18 23:18:57,385 DEV : loss 0.07748273760080338 - f1-score (micro avg)  0.939
2023-04-18 23:18:57,396 BAD EPOCHS (no improvement): 0
2023-04-18 23:18:57,401 saving best model





2023-04-18 23:18:58,881 ----------------------------------------------------------------------------------------------------
2023-04-18 23:19:21,711 epoch 10 - iter 17/174 - loss 0.11250636 - time (sec): 22.83 - samples/sec: 446.29 - lr: 0.005000
2023-04-18 23:19:44,456 epoch 10 - iter 34/174 - loss 0.11389318 - time (sec): 45.57 - samples/sec: 446.87 - lr: 0.005000
2023-04-18 23:20:09,318 epoch 10 - iter 51/174 - loss 0.11803931 - time (sec): 70.44 - samples/sec: 452.89 - lr: 0.005000
2023-04-18 23:20:33,499 epoch 10 - iter 68/174 - loss 0.11796891 - time (sec): 94.62 - samples/sec: 450.97 - lr: 0.005000
2023-04-18 23:20:58,913 epoch 10 - iter 85/174 - loss 0.11811340 - time (sec): 120.03 - samples/sec: 448.78 - lr: 0.005000
2023-04-18 23:21:23,881 epoch 10 - iter 102/174 - loss 0.11652581 - time (sec): 145.00 - samples/sec: 446.50 - lr: 0.005000
2023-04-18 23:21:47,607 epoch 10 - iter 119/174 - loss 0.11574878 - time (sec): 168.72 - samples/sec: 447.84 - lr: 0.005000
2023-04-18 23:22

100%|██████████| 20/20 [00:11<00:00,  1.80it/s]

2023-04-18 23:23:15,904 Evaluating as a multi-label problem: False
2023-04-18 23:23:15,931 DEV : loss 0.07097325474023819 - f1-score (micro avg)  0.9479
2023-04-18 23:23:15,943 BAD EPOCHS (no improvement): 0
2023-04-18 23:23:15,950 saving best model





2023-04-18 23:23:19,286 ----------------------------------------------------------------------------------------------------
2023-04-18 23:23:24,500 SequenceTagger predicts: Dictionary with 39 tags: O, S-Drug, B-Drug, E-Drug, I-Drug, S-Strength, B-Strength, E-Strength, I-Strength, S-Form, B-Form, E-Form, I-Form, S-Frequency, B-Frequency, E-Frequency, I-Frequency, S-Route, B-Route, E-Route, I-Route, S-Dosage, B-Dosage, E-Dosage, I-Dosage, S-Reason, B-Reason, E-Reason, I-Reason, S-Duration, B-Duration, E-Duration, I-Duration, S-ADE, B-ADE, E-ADE, I-ADE, <START>, <STOP>


100%|██████████| 1588/1588 [14:25<00:00,  1.83it/s]


2023-04-18 23:37:51,163 Evaluating as a multi-label problem: False
2023-04-18 23:37:53,272 0.8847	0.9185	0.9013	0.8278
2023-04-18 23:37:53,278 
Results:
- F-score (micro) 0.9013
- F-score (macro) 0.8149
- Accuracy 0.8278

By class:
              precision    recall  f1-score   support

        Drug     0.8697    0.9400    0.9035     61167
    Strength     0.9315    0.9573    0.9443     42957
        Form     0.9221    0.9199    0.9210     41417
   Frequency     0.8543    0.8591    0.8567     36495
       Route     0.9502    0.9449    0.9475     30583
      Dosage     0.9261    0.9432    0.9345     23506
      Reason     0.7297    0.7431    0.7363      9533
         ADE     0.2168    0.6759    0.3283      1299
    Duration     0.7432    0.7810    0.7616      1982

   micro avg     0.8847    0.9185    0.9013    248939
   macro avg     0.7937    0.8627    0.8149    248939
weighted avg     0.8923    0.9185    0.9044    248939

2023-04-18 23:37:53,279 ---------------------------------------

{'test_score': 0.9012660522975776,
 'dev_score_history': [0.8478642480983032,
  0.8683901292596946,
  0.8923346986541837,
  0.9022248243559718,
  0.9158604514805043,
  0.9214536928487691,
  0.926629640456007,
  0.9353321575543798,
  0.9389671361502349,
  0.9478952016485134],
 'train_loss_history': [0.26979004768039777,
  0.220478675166595,
  0.19417269020893674,
  0.1737861580474725,
  0.15649715688922705,
  0.14682087960186904,
  0.1336496293331674,
  0.12747490114608317,
  0.1220010997321469,
  0.11548307869139476],
 'dev_loss_history': [0.1806982308626175,
  0.14165449142456055,
  0.12228060513734818,
  0.11125482618808746,
  0.10138115286827087,
  0.09453363716602325,
  0.09184014052152634,
  0.08141843974590302,
  0.07748273760080338,
  0.07097325474023819]}

In [None]:
sentence = Sentence("Patients on 40 mg of Topelfate and Topoxy twice a day for stomachache generally suffer from headache")

# Tag the same sentence twice and print it each time:
#   pass 1 (True)  -> one BIOES tag per token,
#   pass 2 (False) -> the default span-level entities.
for token_level in (True, False):
    tf_tagger.predict(sentence, force_token_predictions=token_level)
    print(sentence.to_tagged_string())

Sentence[17]: "Patients on 40 mg of Topelfate and Topoxy twice a day for stomachache generally suffer from headache" → ["40"/B-Strength, "mg"/E-Strength, "Topelfate"/S-Drug, "Topoxy"/S-Drug, "twice"/B-Frequency, "a"/I-Frequency, "day"/E-Frequency, "stomachache"/S-Reason, "headache"/S-ADE]
Sentence[17]: "Patients on 40 mg of Topelfate and Topoxy twice a day for stomachache generally suffer from headache" → ["40 mg"/Strength, "Topelfate"/Drug, "Topoxy"/Drug, "twice a day"/Frequency, "stomachache"/Reason, "headache"/ADE]
