In [8]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [1]:
%%capture
!pip install "flair" -q
!pip install "scispacy" -q
!pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.1/en_core_sci_sm-0.5.1.tar.gz -q

LSTM-CRF : [link](https://github.com/flairNLP/flair/blob/master/resources/docs/HUNFLAIR_TUTORIAL_2_TRAINING.md)

Transformer: [link](https://github.com/flairNLP/flair/blob/master/resources/docs/TUTORIAL_TRAINING_SEQUENCE_LABELER.md)

In [1]:
import flair
from flair.data import Sentence
from flair.datasets import ColumnCorpus
from flair.embeddings import (
    WordEmbeddings, FlairEmbeddings, StackedEmbeddings
)
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer
flair.__version__

'0.12.2'

In [2]:
DATA_PATH = "/content/drive/MyDrive/Courses/2. Spring 23/1. CIS522/Project/data"
MODEL_PATH = "/content/drive/MyDrive/Courses/2. Spring 23/1. CIS522/Project/models"

In [3]:
columns = {0:"text", 1:"ner"}

filename = "flair_train.txt"
test_file = "flair_test.txt"

corpus = ColumnCorpus(
    DATA_PATH, columns, train_file=filename, test_file=test_file
)
tag_dictionary = corpus.make_label_dictionary(label_type="ner", add_unk=False)
print(tag_dictionary.get_items())

2023-04-04 14:09:33,805 Reading data from /content/drive/MyDrive/Courses/2. Spring 23/1. CIS522/Project/data
2023-04-04 14:09:33,808 Train: /content/drive/MyDrive/Courses/2. Spring 23/1. CIS522/Project/data/flair_train.txt
2023-04-04 14:09:33,810 Dev: None
2023-04-04 14:09:33,811 Test: /content/drive/MyDrive/Courses/2. Spring 23/1. CIS522/Project/data/flair_test.txt


In [25]:
weight_dict = {
    'Drug': 87168/87168,
    'Strength': 87168/60400,
    'Form': 87168/57184,
    'Frequency': 87168/49699,
    'Route': 87168/41022,
    'Dosage': 87168/33289,
    'Reason': 87168/14242,
    'Duration': 87168/3350,
    'ADE': 87168/2260,
}
weight_dict

{'Drug': 1.0,
 'Strength': 1.4431788079470198,
 'Form': 1.5243424734191382,
 'Frequency': 1.753918589911266,
 'Route': 2.1249085856369754,
 'Dosage': 2.6185226351046893,
 'Reason': 6.120488695407948,
 'Duration': 26.020298507462687,
 'ADE': 38.56991150442478}

In [4]:
corpus.train[0]

Sentence[29]: "He also may have recurrent seizures which should be treated with ativan IV or IM and do not neccessarily indicate patient needs to return to hospital unless they continue" → ["recurrent seizures"/Reason, "ativan"/Drug, "IV"/Route, "IM"/Route]

In [5]:
corpus.test[0]

Sentence[14]: "MEDICATIONS : Lipitor , Tylenol with Codeine , Dilantin , previously on Decadron q.i.d" → ["Lipitor"/Drug, "Tylenol with Codeine"/Drug, "Dilantin"/Drug, "Decadron"/Drug, "q.i.d"/Frequency]

In [24]:
embedding_types = [
    # word embeddings trained on PubMed and PMC
    WordEmbeddings("pubmed"),
    # flair embeddings trained on PubMed and PMC
    FlairEmbeddings("pubmed-forward"),
    FlairEmbeddings("pubmed-backward"),
]

embeddings = StackedEmbeddings(embeddings=embedding_types)

In [26]:
tagger = SequenceTagger(
    hidden_size=256,
    embeddings=embeddings,
    tag_dictionary=tag_dictionary,
    tag_type="ner",
    rnn_type='LSTM',
    use_crf=True,
    locked_dropout=0.5,
    loss_weights=weight_dict
)

2023-04-04 14:53:23,223 SequenceTagger predicts: Dictionary with 37 tags: O, S-Drug, B-Drug, E-Drug, I-Drug, S-Strength, B-Strength, E-Strength, I-Strength, S-Form, B-Form, E-Form, I-Form, S-Frequency, B-Frequency, E-Frequency, I-Frequency, S-Route, B-Route, E-Route, I-Route, S-Dosage, B-Dosage, E-Dosage, I-Dosage, S-Reason, B-Reason, E-Reason, I-Reason, S-Duration, B-Duration, E-Duration, I-Duration, S-ADE, B-ADE, E-ADE, I-ADE


In [28]:
trainer = ModelTrainer(tagger, corpus)

# 4. train on the target corpus
trainer.train(
    base_path=f"{MODEL_PATH}/taggers/lstm-crf",
    train_with_dev=False,
    max_epochs=1,
    learning_rate=0.1,
    mini_batch_size=64,
    embeddings_storage_mode='none'
)

In [29]:
# loaded_model = SequenceTagger.load(f"{MODEL_PATH}/taggers/lstm-crf/final-model.pt")

In [30]:
# # create example sentence
# sentence = Sentence("Women who smoke 20 cigarettes a day are four times more likely to develop breast cancer.")

# # predict tags and print
# loaded_model.predict(sentence)

# print(sentence.to_tagged_string())