In [None]:
# Mount Google Drive so the shared project folder (data files and saved models)
# is reachable under /content/drive for the rest of the notebook.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
%%capture
!pip install "flair" -q

In [None]:
# flair core plus the data, dataset, embedding, model and trainer classes used in this notebook.
import flair
from flair.data import Sentence
from flair.datasets import ColumnCorpus, DataLoader
from flair.embeddings import (
    FlairEmbeddings,
    StackedEmbeddings,
    TransformerWordEmbeddings,
    WordEmbeddings,
)
from flair.models import RelationExtractor, SequenceTagger
from flair.trainers import ModelTrainer

# Echo the installed version so the saved output records which flair release was used.
flair.__version__

'0.12.2'

In [None]:
# Root of the shared Drive project folder; both the data and the model directories live under it.
PROJECT_ROOT = "/content/drive/Shareddrives/CIS522-Project"
# DATA_PATH keeps a trailing slash because later cells build file paths by plain
# concatenation (DATA_PATH + 'file'); MODEL_PATH is used inside f-strings that add their own '/'.
DATA_PATH = f"{PROJECT_ROOT}/data/"
MODEL_PATH = f"{PROJECT_ROOT}/models"

In [None]:
import pandas as pd

# Relation rows for each split, read from the parquet exports:
# df / counts_dict -> train split, df1 / counts_dict1 -> test split.
df = pd.read_parquet(f"{DATA_PATH}rel_train.parquet")
df1 = pd.read_parquet(f"{DATA_PATH}rel_test.parquet")

# Number of rows per relation label in each split.
counts_dict = df["label"].value_counts().to_dict()
counts_dict1 = df1["label"].value_counts().to_dict()

In [None]:
# Fraction of all training relations carried by each label, with keys in alphabetical order.
total = sum(counts_dict.values())
train_d = {label: count / float(total) for label, count in sorted(counts_dict.items())}
train_d

{'ADE-Drug': 0.030457271776811755,
 'Dosage-Drug': 0.11624387828096627,
 'Duration-Drug': 0.017691080173884335,
 'Form-Drug': 0.18307379078853245,
 'Frequency-Drug': 0.17360920046222417,
 'Reason-Drug': 0.1421614483024267,
 'Route-Drug': 0.15236889891597424,
 'Strength-Drug': 0.1843944312991801}

In [None]:
# Fraction of all test relations carried by each label, with keys in alphabetical order
# (compare against train_d to check the two splits have similar label distributions).
total = sum(counts_dict1.values())
test_d = {label: count / float(total) for label, count in sorted(counts_dict1.items())}
test_d

{'ADE-Drug': 0.031242008353934023,
 'Dosage-Drug': 0.11486659278833859,
 'Duration-Drug': 0.018157019861904357,
 'Form-Drug': 0.1864291194271588,
 'Frequency-Drug': 0.17193760122751683,
 'Reason-Drug': 0.14534140311993862,
 'Route-Drug': 0.1511380103997954,
 'Strength-Drug': 0.18088824482141336}

In [None]:
import random

# Load the relation corpus from flair column files: column 1 holds the token text and
# column 2 the NER tag; lines starting with "# " are comment/metadata lines.
#
# No dev file is supplied ("Dev: None" in the log), yet training later reports DEV
# scores, so flair samples a dev split out of the train file. That sampling is random
# (presumably via Python's global `random` module -- TODO confirm against
# flair.data.Corpus). Without a seed, every session draws a different split. Resumed
# training would then score "dev" sentences it already trained on, and per-class train
# counts drift between runs. Seed once here so the split is reproducible.
SPLIT_SEED = 42
random.seed(SPLIT_SEED)

corpus = ColumnCorpus(
    DATA_PATH, {1: 'text', 2: 'ner'},
    train_file='flair_rel_train.txt', test_file='flair_rel_test.txt',
    comment_symbol="# "
)

2023-04-12 13:52:21,417 Reading data from /content/drive/Shareddrives/CIS522-Project/data
2023-04-12 13:52:21,429 Train: /content/drive/Shareddrives/CIS522-Project/data/flair_rel_train.txt
2023-04-12 13:52:21,433 Dev: None
2023-04-12 13:52:21,436 Test: /content/drive/Shareddrives/CIS522-Project/data/flair_rel_test.txt


In [None]:
# Peek at the first training example: its tokens plus the entity and relation annotations.
first_train_sentence = corpus.train[0]
first_train_sentence

Sentence[29]: "He also may have recurrent seizures which should be treated with ativan IV or IM and do not neccessarily indicate patient needs to return to hospital unless they continue" → ["recurrent seizures"/Reason, "recurrent seizures -> ativan"/Reason-Drug, "ativan"/Drug, "IV"/Route, "IM"/Route]

In [None]:
# Peek at the first test example to confirm the test file parsed the same way as train.
first_test_sentence = corpus.test[0]
first_test_sentence

Sentence[14]: "MEDICATIONS : Lipitor , Tylenol with Codeine , Dilantin , previously on Decadron q.i.d" → ["Lipitor"/Drug, "Tylenol with Codeine"/Drug, "Dilantin"/Drug, "Decadron"/Drug, "q.i.d -> Decadron"/Frequency-Drug, "q.i.d"/Frequency]

In [None]:
# Vocabulary of relation labels seen in the corpus, built without an <unk> entry.
label_dictionary = corpus.make_label_dictionary(add_unk=False, label_type="relation")
# Append 'O' so candidate entity pairs that hold no valid relation get a class of their own.
label_dictionary.add_item("O")
print(label_dictionary.get_items())

2023-04-10 06:59:02,099 Computing label dictionary. Progress:


32711it [00:00, 43156.08it/s]

2023-04-10 06:59:02,899 Dictionary created for label 'relation' with 8 values: Strength-Drug (seen 6063 times), Form-Drug (seen 5968 times), Frequency-Drug (seen 5691 times), Route-Drug (seen 4974 times), Reason-Drug (seen 4642 times), Dosage-Drug (seen 3785 times), ADE-Drug (seen 1011 times), Duration-Drug (seen 577 times)
['Strength-Drug', 'Form-Drug', 'Frequency-Drug', 'Route-Drug', 'Reason-Drug', 'Dosage-Drug', 'ADE-Drug', 'Duration-Drug', 'O']





In [None]:
# Inverse-frequency loss weights for the relation classes:
#   weight(label) = count(most frequent label) / count(label)
# The most common class gets 1.0 and rare classes (ADE-Drug, Duration-Drug) are up-weighted.
# The counts come from `counts_dict` (label counts of rel_train.parquet) rather than
# hand-typed numbers. The previous literals (6021, 6005, ...) no longer matched the label
# counts reported for the current corpus (6063, 5968, ...). Only the ratios matter, so
# using the full train file rather than flair's post-dev-split train subset is fine.
max_count = max(counts_dict.values())
weight_dict = {
    label: max_count / count
    # most frequent label first, matching the label-dictionary ordering
    for label, count in sorted(counts_dict.items(), key=lambda item: item[1], reverse=True)
}
weight_dict

{'Strength-Drug': 1.0,
 'Form-Drug': 1.0026644462947543,
 'Frequency-Drug': 1.057057584269663,
 'Route-Drug': 1.220308066477503,
 'Reason-Drug': 1.2895695009638037,
 'Dosage-Drug': 1.5799002886381528,
 'ADE-Drug': 6.045180722891566,
 'Duration-Drug': 10.398963730569948}

In [None]:
# Either build a fresh relation extractor on stacked PubMed Flair embeddings, or resume
# from the best checkpoint of an earlier run. Flip this flag instead of commenting
# code in and out; the default (False) reloads the saved checkpoint.
TRAIN_FROM_SCRATCH = False

if TRAIN_FROM_SCRATCH:
    embedding_types = [
        FlairEmbeddings("pubmed-forward", fine_tune=True),
        FlairEmbeddings("pubmed-backward", fine_tune=True),
    ]
    embeddings = StackedEmbeddings(embeddings=embedding_types)

    rel_extractor = RelationExtractor(
        embeddings=embeddings,
        label_type="relation",
        entity_label_type='ner',
        pooling_operation="first_last",
        label_dictionary=label_dictionary,
        loss_weights=weight_dict,
        # Only (attribute entity, Drug) pairs are considered as relation candidates.
        entity_pair_filters=[
            ('Strength', 'Drug'),
            ('Form', 'Drug'),
            ('Frequency', 'Drug'),
            ('Route', 'Drug'),
            ('Reason', 'Drug'),
            ('Dosage', 'Drug'),
            ('ADE', 'Drug'),
            ('Duration', 'Drug')
        ]
    )
else:
    rel_extractor = RelationExtractor.load(
        f"{MODEL_PATH}/extractors/flair-embedding-rel/best-model.pt"
    )

In [None]:
# Train (or resume training) the relation extractor on `corpus`. No dev file exists,
# so the DEV scores in the log are computed on the split flair sampled out of train.
# NOTE(review): base_path is the same directory the checkpoint was loaded from. The log
# shows "saving best model" right after the first epoch of this fresh trainer, so the
# previous best-model.pt is presumably overwritten regardless of how it compared.
# Consider writing resumed runs to a separate directory.
training_args = {
    "base_path": f"{MODEL_PATH}/extractors/flair-embedding-rel",
    "train_with_dev": False,
    "max_epochs": 5,
    "learning_rate": 0.1,
    "mini_batch_size": 8,
    "embeddings_storage_mode": "none",
}
trainer = ModelTrainer(rel_extractor, corpus)
trainer.train(**training_args)



2023-04-10 06:59:17,776 ----------------------------------------------------------------------------------------------------
2023-04-10 06:59:17,820 Model: "RelationExtractor(
  (embeddings): StackedEmbeddings(
    (list_embedding_0): FlairEmbeddings(
      (lm): LanguageModel(
        (drop): Dropout(p=0.1, inplace=False)
        (encoder): Embedding(275, 100)
        (rnn): LSTM(100, 2048)
      )
    )
    (list_embedding_1): FlairEmbeddings(
      (lm): LanguageModel(
        (drop): Dropout(p=0.1, inplace=False)
        (encoder): Embedding(275, 100)
        (rnn): LSTM(100, 2048)
      )
    )
  )
  (decoder): Linear(in_features=16384, out_features=9, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
  (locked_dropout): LockedDropout(p=0.0)
  (word_dropout): WordDropout(p=0.0)
  (loss_function): CrossEntropyLoss()
  (weights): {'Strength-Drug': 1.0, 'Form-Drug': 1.0026644462947543, 'Frequency-Drug': 1.057057584269663, 'Route-Drug': 1.220308066477503, 'Reason-Drug': 1.28956950

100%|██████████| 455/455 [02:04<00:00,  3.66it/s]


2023-04-10 07:36:56,602 Evaluating as a multi-label problem: False
2023-04-10 07:36:56,653 DEV : loss 0.04677500203251839 - f1-score (micro avg)  0.7725
2023-04-10 07:36:56,759 BAD EPOCHS (no improvement): 0
2023-04-10 07:36:56,766 saving best model
2023-04-10 07:36:57,240 ----------------------------------------------------------------------------------------------------
2023-04-10 07:41:00,290 epoch 2 - iter 408/4089 - loss 0.05602665 - time (sec): 243.04 - samples/sec: 325.82 - lr: 0.100000
2023-04-10 07:45:01,491 epoch 2 - iter 816/4089 - loss 0.05558043 - time (sec): 484.25 - samples/sec: 322.36 - lr: 0.100000
2023-04-10 07:48:55,792 epoch 2 - iter 1224/4089 - loss 0.05598320 - time (sec): 718.55 - samples/sec: 326.55 - lr: 0.100000
2023-04-10 07:53:01,654 epoch 2 - iter 1632/4089 - loss 0.05599502 - time (sec): 964.41 - samples/sec: 324.47 - lr: 0.100000
2023-04-10 07:57:00,004 epoch 2 - iter 2040/4089 - loss 0.05654076 - time (sec): 1202.76 - samples/sec: 324.33 - lr: 0.100000
2

100%|██████████| 455/455 [02:03<00:00,  3.67it/s]

2023-04-10 08:19:24,024 Evaluating as a multi-label problem: False





2023-04-10 08:19:24,097 DEV : loss 0.05283693969249725 - f1-score (micro avg)  0.7506
2023-04-10 08:19:24,247 BAD EPOCHS (no improvement): 1
2023-04-10 08:19:24,256 ----------------------------------------------------------------------------------------------------
2023-04-10 08:23:27,432 epoch 3 - iter 408/4089 - loss 0.04955468 - time (sec): 243.17 - samples/sec: 324.34 - lr: 0.100000
2023-04-10 08:27:28,616 epoch 3 - iter 816/4089 - loss 0.05101356 - time (sec): 484.36 - samples/sec: 320.45 - lr: 0.100000
2023-04-10 08:31:34,290 epoch 3 - iter 1224/4089 - loss 0.05101134 - time (sec): 730.03 - samples/sec: 319.24 - lr: 0.100000
2023-04-10 08:35:40,693 epoch 3 - iter 1632/4089 - loss 0.05073283 - time (sec): 976.44 - samples/sec: 319.06 - lr: 0.100000
2023-04-10 08:39:41,051 epoch 3 - iter 2040/4089 - loss 0.05053183 - time (sec): 1216.79 - samples/sec: 320.79 - lr: 0.100000
2023-04-10 08:43:47,500 epoch 3 - iter 2448/4089 - loss 0.05063105 - time (sec): 1463.24 - samples/sec: 318.57

100%|██████████| 455/455 [02:03<00:00,  3.68it/s]

2023-04-10 09:01:58,551 Evaluating as a multi-label problem: False
2023-04-10 09:01:58,592 DEV : loss 0.055398471653461456 - f1-score (micro avg)  0.7647





2023-04-10 09:01:58,690 BAD EPOCHS (no improvement): 2
2023-04-10 09:01:58,704 ----------------------------------------------------------------------------------------------------
2023-04-10 09:05:51,901 epoch 4 - iter 408/4089 - loss 0.04410451 - time (sec): 233.20 - samples/sec: 333.09 - lr: 0.100000
2023-04-10 09:09:53,077 epoch 4 - iter 816/4089 - loss 0.04462273 - time (sec): 474.37 - samples/sec: 328.49 - lr: 0.100000
2023-04-10 09:13:56,548 epoch 4 - iter 1224/4089 - loss 0.04474615 - time (sec): 717.84 - samples/sec: 326.99 - lr: 0.100000
2023-04-10 09:17:55,126 epoch 4 - iter 1632/4089 - loss 0.04459590 - time (sec): 956.42 - samples/sec: 328.25 - lr: 0.100000
2023-04-10 09:21:54,601 epoch 4 - iter 2040/4089 - loss 0.04453981 - time (sec): 1195.90 - samples/sec: 329.91 - lr: 0.100000
2023-04-10 09:26:05,375 epoch 4 - iter 2448/4089 - loss 0.04517755 - time (sec): 1446.67 - samples/sec: 328.59 - lr: 0.100000
2023-04-10 09:30:10,133 epoch 4 - iter 2856/4089 - loss 0.04523071 - t

100%|██████████| 455/455 [02:03<00:00,  3.69it/s]

2023-04-10 09:44:18,121 Evaluating as a multi-label problem: False
2023-04-10 09:44:18,163 DEV : loss 0.05601434409618378 - f1-score (micro avg)  0.7597





2023-04-10 09:44:18,262 BAD EPOCHS (no improvement): 3
2023-04-10 09:44:18,268 ----------------------------------------------------------------------------------------------------
2023-04-10 09:48:19,467 epoch 5 - iter 408/4089 - loss 0.04270323 - time (sec): 241.20 - samples/sec: 326.25 - lr: 0.100000
2023-04-10 09:52:26,236 epoch 5 - iter 816/4089 - loss 0.04223448 - time (sec): 487.96 - samples/sec: 326.08 - lr: 0.100000
2023-04-10 09:56:29,273 epoch 5 - iter 1224/4089 - loss 0.04259352 - time (sec): 731.00 - samples/sec: 325.40 - lr: 0.100000
2023-04-10 10:00:25,258 epoch 5 - iter 1632/4089 - loss 0.04241547 - time (sec): 966.99 - samples/sec: 326.13 - lr: 0.100000
2023-04-10 10:04:22,193 epoch 5 - iter 2040/4089 - loss 0.04249717 - time (sec): 1203.92 - samples/sec: 326.56 - lr: 0.100000
2023-04-10 10:08:20,566 epoch 5 - iter 2448/4089 - loss 0.04243990 - time (sec): 1442.29 - samples/sec: 326.56 - lr: 0.100000
2023-04-10 10:12:27,995 epoch 5 - iter 2856/4089 - loss 0.04242879 - t

100%|██████████| 455/455 [02:03<00:00,  3.68it/s]

2023-04-10 10:26:46,102 Evaluating as a multi-label problem: False
2023-04-10 10:26:46,144 DEV : loss 0.043203406035900116 - f1-score (micro avg)  0.8255





2023-04-10 10:26:46,245 BAD EPOCHS (no improvement): 0
2023-04-10 10:26:46,252 saving best model
2023-04-10 10:26:47,314 ----------------------------------------------------------------------------------------------------


 78%|███████▊  | 2289/2933 [09:46<02:42,  3.97it/s]

In [None]:
# Evaluate on the held-out test split, scoring predictions against the gold
# 'relation' labels, and print the per-class precision / recall / F1 report.
eval_settings = {"gold_label_type": "relation", "mini_batch_size": 64}
result = rel_extractor.evaluate(corpus.test, **eval_settings)
print(result.detailed_results)

100%|██████████| 367/367 [04:32<00:00,  1.35it/s]

2023-04-12 14:04:15,026 Evaluating as a multi-label problem: False






Results:
- F-score (micro) 0.8235
- F-score (macro) 0.7998
- Accuracy 0.7124

By class:
                precision    recall  f1-score   support

     Form-Drug     0.8904    0.8267    0.8574      4374
Frequency-Drug     0.8763    0.9251    0.9000      4034
    Route-Drug     0.6760    0.8739    0.7624      3546
 Strength-Drug     0.9228    0.7717    0.8405      4244
   Reason-Drug     0.6987    0.8487    0.7664      3410
   Dosage-Drug     0.8643    0.8883    0.8761      2695
      ADE-Drug     0.5217    0.7872    0.6275       733
 Duration-Drug     0.6818    0.8803    0.7684       426

     micro avg     0.7978    0.8508    0.8235     23462
     macro avg     0.7665    0.8502    0.7998     23462
  weighted avg     0.8153    0.8508    0.8274     23462



In [None]:
# Load the saved 'lstm-crf-augmented' NER tagger (its final checkpoint, not best-model.pt);
# the example cell uses it to tag entities before relation extraction.
tagger_checkpoint = f"{MODEL_PATH}/taggers/lstm-crf-augmented/final-model.pt"
tagger = SequenceTagger.load(tagger_checkpoint)

In [None]:
# Smoke-test the two-stage pipeline on a made-up sentence: the NER tagger adds entity
# spans first, then the relation extractor predicts relations between those entities.
# NOTE(review): assumes the tagger writes its spans under the label type the extractor
# reads ('ner' in this notebook's extractor config) -- confirm, otherwise no relations appear.
example_text = "Patients on 40 mg of Topelfate and Topoxy twice a day generally suffer from headache"
sentence = Sentence(example_text)

for pipeline_stage in (tagger, rel_extractor):  # order matters: relations need the entities
    pipeline_stage.predict(sentence)

print(sentence.to_tagged_string())