In [1]:
# Use this demo to compare EEQA and generative-QA (gen_qa) results.

In [2]:
import os
# Expose only GPU 1 to this kernel. This is set before the transformers import
# below so it is in place before any CUDA initialisation.
# NOTE(review): no visible cell moves the model to a GPU (no `.to(...)`/`.cuda()`),
# so generation appears to run on CPU regardless — confirm this is intended.
os.environ.update({"CUDA_VISIBLE_DEVICES": "1"})
# Relative paths used later (e.g. "../data/phee_genqa/...") resolve against this directory.
os.chdir("/home/zha0yuewarwick/projects/PhEE/gen_qa")

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, utils
from bertviz import head_view
import argparse
import json
import random
import spacy
from spacy.tokens import Span
import medspacy
from medspacy.ner import TargetRule
from medspacy.visualization import visualize_ent

In [3]:
# Index into each ARG_TEMPLATES list; it also selects the dev file and the
# per-template fine-tuned checkpoint.
QUESTION_TYPE = 8

# Run configuration, built in one call. Keyword order matches the original
# attribute order, so vars(args) iterates identically.
args = argparse.Namespace(
    question_type=QUESTION_TYPE,
    num_beams=3,
    n_best_size=20,
    test_file="../data/phee_genqa/argument/dev_%d.json" % QUESTION_TYPE,
    model_name_or_path="/home/zha0yuewarwick/projects/PhEE/gen_qa/model/SciFive-base-PMC/arg%d" % QUESTION_TYPE,
    max_answer_length=150,
)


In [4]:
# Question templates per argument role, indexed by args.question_type.
# Each list has 9 entries laid out as three phrasings x three slots:
#   index // 3 -> phrasing: 0 bare role name, 1 generic question, 2 role-specific question
#   index % 3  -> slot:     0 none, 1 "in [EVENT]", 2 "in [TRIGGER]"
# "[EVENT]" and "[TRIGGER]" are placeholders filled in by get_sample_result.
ARG_TEMPLATES = {
    "Subject": [
        "Subject",
        "Subject in [EVENT]",
        "Subject in [TRIGGER]",
        "Who is the subject?",
        "Who is the subject in [EVENT]?",
        "Who is the subject in [TRIGGER]?",
        "Who is treated?",
        "Who is treated in [EVENT]?",
        "Who is treated in [TRIGGER]?",
    ],
    "Treatment": [
        "Treatment",
        "Treatment in [EVENT]",
        "Treatment in [TRIGGER]",
        "What is the treatment?",
        "What is the treatment in [EVENT]?",
        "What is the treatment in [TRIGGER]?",
        "What treatment is given to the patient?",
        "What treatment is given to the patient in [EVENT]?",
        "What treatment is given to the patient in [TRIGGER]?",
    ],
    "Effect": [
        "Effect",
        "Effect in [EVENT]",
        "Effect in [TRIGGER]",
        "What is the effect?",
        "What is the effect in [EVENT]?",
        "What is the effect in [TRIGGER]?",
        "What effect does the treatment cause?",
        "What effect does the treatment cause in [EVENT]?",
        "What effect does the treatment cause in [TRIGGER]?",
    ],
}

In [5]:
# Read cases from the EEQA prediction file: one JSON object per line (JSON Lines).
# The compared EEQA question template is arg_query + "in trigger" (template 5).
eeqa_result_file = "/home/zha0yuewarwick/projects/PhEE/eeqa/model/biobert_rst/argument/template5/no_overlap/pred_outputs.json"
eeqa_results = []
# JSON text is UTF-8. Pass the encoding explicitly rather than relying on the
# platform's locale default, so non-ASCII tokens decode the same everywhere.
with open(eeqa_result_file, "r", encoding="utf-8") as f:
    for line in f:
        # Skip blank lines (e.g. a trailing empty line) instead of crashing in json.loads.
        if line.strip():
            eeqa_results.append(json.loads(line))

In [6]:
# Load the fine-tuned generative-QA (seq2seq) checkpoint and its tokenizer from
# the same directory. output_attentions=True makes the model return attention
# weights; this is presumably for attention visualisation with bertviz's
# head_view, which is imported above.
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
model = AutoModelForSeq2SeqLM.from_pretrained(
    args.model_name_or_path,
    output_attentions=True,
)

In [7]:
def _join_role_spans(sent_tokens, spans, role):
    """Join the text of every span whose role label equals ``role``.

    Args:
        sent_tokens: the sentence as a list of tokens.
        spans: iterable of ``[start, end, ..., role]`` entries, where ``start`` and
            ``end`` are inclusive token indices (hence ``end + 1`` in the slice).
        role: argument role to keep, e.g. ``"Treatment"``.

    Returns:
        The matching span texts joined with ``"; "``, or ``""`` when none match.
    """
    return "; ".join(
        " ".join(sent_tokens[span[0]: span[1] + 1])
        for span in spans
        if span[-1] == role
    )


def get_sample_result(select_id, test_role):
    """Collect the gold, EEQA and generative-QA answers for one sentence and role.

    Reads the notebook-level ``eeqa_results``, ``ARG_TEMPLATES``, ``args``,
    ``tokenizer`` and ``model`` defined in earlier cells.

    Args:
        select_id: index into ``eeqa_results``.
        test_role: argument role to query (a key of ``ARG_TEMPLATES``, e.g.
            ``"Subject"``, ``"Treatment"`` or ``"Effect"``).

    Returns:
        dict with
          - ``context``: the space-joined sentence
          - ``event_type``: the event label, e.g. ``"Adverse_event"``
          - ``gold_span``: gold spans for ``test_role``, joined with ``"; "``
          - ``eeqa_span``: EEQA predicted spans for ``test_role``, joined with ``"; "``
          - ``gen_qa_span``: the generated answer, with special tokens removed
          - ``encoder_text`` / ``decoder_text``: the token strings of the model
            input and output
          - ``encoder_input_ids`` / ``decoder_input_ids``: the raw id tensors,
            kept for attention inspection
    """
    eeqa_rst = eeqa_results[select_id]
    sent_tokens = eeqa_rst['sentence']
    context = " ".join(sent_tokens)

    # Assume a sentence contains only one event; use the first.
    event = eeqa_rst['event'][0]
    trigger = event[0]
    event_type = trigger[-1]
    trig_st = trigger[0]
    # NOTE(review): only the trigger's first token is used. Confirm that triggers
    # are single-token in this data.
    trigger_text = sent_tokens[trig_st]

    # Gold arguments follow the trigger entry; EEQA predictions live in 'arg_pred'.
    gold_span = _join_role_spans(sent_tokens, event[1:], test_role)
    eeqa_span = _join_role_spans(sent_tokens, eeqa_rst['arg_pred'], test_role)

    # Build the gen_qa input.
    question = ARG_TEMPLATES[test_role][args.question_type] # TODO: modify this to match eeqa method, but we currently only train the argument type
    # e.g. "Adverse_event" -> "adverse event"
    question = question.replace("[EVENT]", " ".join(event_type.split('_')).lower())
    question = question.replace("[TRIGGER]", trigger_text)

    input_text = " ".join(["question:", question.lstrip(), "context:", context.lstrip()])

    encoded = tokenizer(input_text, return_tensors="pt", add_special_tokens=True)
    encoder_input_ids = encoded.input_ids
    # Deterministic beam search. `top_p` is deliberately not passed: it is a
    # nucleus-sampling probability in (0, 1] and is only used when sampling.
    # The previous value (args.n_best_size == 20) was not a valid probability.
    decoder_input_ids = model.generate(
        encoder_input_ids,
        attention_mask=encoded.attention_mask,
        num_beams=args.num_beams,
        max_length=args.max_answer_length,
    )

    # Index 0: a single input sentence is encoded per call.
    encoder_text = tokenizer.convert_ids_to_tokens(encoder_input_ids[0])
    decoder_text = tokenizer.convert_ids_to_tokens(decoder_input_ids[0])
    gen_qa_span = tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)

    return {
        'context': context,
        'event_type': event_type,
        'gold_span': gold_span,
        'eeqa_span': eeqa_span,
        'gen_qa_span': gen_qa_span,
        'encoder_text': encoder_text,
        'decoder_text': decoder_text,
        'encoder_input_ids': encoder_input_ids,
        'decoder_input_ids': decoder_input_ids
    }

In [8]:
# Inspect one fixed case (sentence 718, Treatment role): the full result dict,
# then the encoder and decoder token counts.
EXAMPLE_ID = 718
sample_rst = get_sample_result(EXAMPLE_ID, "Treatment")
for shown in (sample_rst, len(sample_rst["encoder_text"]), len(sample_rst["decoder_text"])):
    print(shown)

{'context': "Here we present the case of a woman who received high doses of methylprednisolone ( 1 g iv daily ) for active Graves ' ophthalmopathy , and developed severe hypertension followed by myocardial infarction on the fifth day of treatment .", 'event_type': 'Adverse_event', 'gold_span': 'high doses of methylprednisolone ( 1 g iv daily )', 'eeqa_span': 'high doses of methylprednisolone ( 1 g iv daily )', 'gen_qa_span': 'high doses of methylprednisolone ( 1 g iv daily )', 'encoder_text': ['▁question', ':', '▁What', '▁treatment', '▁is', '▁given', '▁to', '▁the', '▁patient', '▁in', '▁received', '?', '▁context', ':', '▁Here', '▁we', '▁present', '▁the', '▁case', '▁of', '▁', 'a', '▁woman', '▁who', '▁received', '▁high', '▁dose', 's', '▁of', '▁', 'methyl', 'pre', 'd', 'n', 'i', 'sol', 'one', '▁(', '▁1', '▁', 'g', '▁', 'i', 'v', '▁daily', '▁', ')', '▁for', '▁active', '▁Grav', 'e', 's', '▁', "'", '▁', 'o', 'phthal', 'm', 'opathy', '▁', ',', '▁and', '▁developed', '▁severe', '▁hyper', 'tensio

### Start Demo: compare gold, EEQA and generative-QA spans on a random case

In [19]:
import warnings

# Sample a case. random.randint includes its upper bound, so the previous
# randint(0, len(eeqa_results)) could return an out-of-range index.
# randrange(n) draws uniformly from 0 .. n-1.
random_id = random.randrange(len(eeqa_results))
select_id = random_id
print("random selected case: %d"%select_id)


ROLES = ["Subject", "Treatment", "Effect"]
# rst[role] holds the gold / EEQA / generative-QA answers for that role.
rst = {}
for role in ROLES:
    sample_rst = get_sample_result(select_id, role)
    rst[role] = sample_rst

context = rst["Subject"]["context"]
print("event type: %s"%rst["Subject"]["event_type"])
# Silence library warnings while building and rendering the medspacy pipelines.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    for task in ["gold_span", "eeqa_span", "gen_qa_span"]:
        # A freshly loaded pipeline per task, so each task's target rules go into
        # a new matcher rather than on top of the previous task's rules.
        nlp = medspacy.load()
        print(task.upper())
        # Add rules for target concept extraction
        target_matcher = nlp.get_pipe("medspacy_target_matcher")

        target_rules = []
        for role in ROLES:
            print("%s: %s"%(role, rst[role][task]))
            arguments = rst[role][task].split("; ")

            for sp in arguments:
                # A role with no answer is "" and splits to [""]. Skip it rather
                # than registering an empty-literal target rule.
                if sp:
                    target_rules.append(TargetRule(sp, role))

        target_matcher.add(target_rules)
        doc = nlp(context)
        visualize_ent(doc)
        
    
    

random selected case: 774
event type: Adverse_event
GOLD_SPAN
Subject: 
Treatment: Dipyrone , also known as metamizole
Effect: agranulocytosis


EEQA_SPAN
Subject: 
Treatment: agranulocytosis
Effect: agranulocytosis


GEN_QA_SPAN
Subject: 
Treatment: Dipyrone, also known as metamizole, is an analgesic and antipyretic drug
Effect: agranulocytosis
