In [3]:
import torch
import transformers
from spacy import displacy
from transformers import (AutoModelForTokenClassification, 
                          AutoTokenizer, 
                          pipeline)

In [4]:
# Directory of the locally fine-tuned SpanBERT token-classification checkpoint.
# NOTE: the path is relative, so the notebook must be started from the directory
# that contains `spanbert-large-cased-finetuned-ner/`.
model_checkpoint = ("spanbert-large-cased-finetuned-ner"
                    "/checkpoint-1000")

In [5]:
# Load the tokenizer from the same checkpoint directory as the model
# (presumably saved there during fine-tuning — confirm the directory contains tokenizer files).
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=model_checkpoint,
)

In [6]:
# Token-classification (NER) head restored from the fine-tuned checkpoint.
model = AutoModelForTokenClassification.from_pretrained(
    pretrained_model_name_or_path=model_checkpoint,
)

In [7]:
# `grouped_entities=True` is deprecated (slated for removal in transformers v5);
# `aggregation_strategy="simple"` is the equivalent replacement and groups
# sub-word predictions into entity spans reported under the "entity_group" key.
# Use the first GPU when one is available, otherwise fall back to CPU (-1)
# instead of crashing on machines without CUDA.
effect_ner_model = pipeline(
    task="ner",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
    aggregation_strategy="simple",
)



We try the model on the first few examples from the "Medications" section of Wikipedia's article on adverse effects, and visualize the predicted entities with spaCy's displaCy visualizer:

https://en.wikipedia.org/wiki/Adverse_effect#Medications

In [9]:
def visualize_entities(sentence, ner_model=None, colors=None):
    """Tag ``sentence`` with an NER pipeline and render the entities with displaCy.

    Args:
        sentence: Raw text to tag and render.
        ner_model: Callable NER pipeline; defaults to the notebook's
            ``effect_ner_model``. Each prediction is expected to carry
            character offsets ``start``/``end`` and a label under
            ``entity_group`` (aggregated pipelines) or ``entity``
            (token-level pipelines).
        colors: Optional mapping of entity label -> CSS colour. Defaults to
            the DRUG/ADR palette used in this notebook.

    Returns:
        The value of ``displacy.render``. Outside a notebook this is the HTML
        markup as a string. NOTE(review): inside Jupyter, displaCy displays
        the markup itself and may return ``None`` — confirm against the
        installed spaCy version.
    """
    if ner_model is None:
        ner_model = effect_ner_model
    if colors is None:
        colors = {
            "DRUG": "#f08080",
            "ADR": "#9bddff",
        }

    entities = []
    for prediction in ner_model(sentence):
        # Aggregated pipelines report "entity_group"; token-level ones report "entity".
        label = prediction.get("entity_group", prediction.get("entity"))
        if label is None or label == "O":
            continue
        # displaCy's manual mode needs character offsets; skip predictions that lack them.
        start, end = prediction.get("start"), prediction.get("end")
        if start is None or end is None:
            continue
        # Build a fresh dict in displaCy's manual "ent" format rather than
        # mutating the pipeline's own output.
        entities.append({"start": start, "end": end, "label": label})

    params = [{"text": sentence,
               "ents": entities,
               "title": None}]

    # Return the render result; the original discarded it, so the HTML was
    # lost whenever the function ran outside a notebook.
    return displacy.render(params, style="ent", manual=True,
                           options={"colors": colors})
    

In [10]:
# Adverse-effect sentences used to probe the model
# (see the "Medications" section of Wikipedia's "Adverse effect" article).
examples = [
    "Abortion, miscarriage or uterine hemorrhage associated with misoprostol (Cytotec), a labor-inducing drug.",
    "Addiction to many sedatives and analgesics, such as diazepam, morphine, etc.",
    "Birth defects associated with thalidomide",
    "Bleeding of the intestine associated with aspirin therapy",
    "Cardiovascular disease associated with COX-2 inhibitors (i.e. Vioxx)",
    "Deafness and kidney failure associated with gentamicin (an antibiotic)",
    "Having fever after taking paracetamol",
]

# Render every sentence; the bare print() emits a blank line between renders.
for sentence in examples:
    visualize_entities(sentence)
    print()




















