In [1]:
# use this demo to compare eeqa result & gen_qa attention analysis

In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.chdir("/home/zha0yuewarwick/projects/PhEE/gen_qa")
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, utils
from bertviz import head_view
import argparse
import json
import random

In [3]:
args = argparse.Namespace()
args.num_beams = 3
args.n_best_size = 10
args.test_file = "../data/phee_genqa/argument/test_2.json"
args.model_name_or_path = "/home/zha0yuewarwick/projects/PhEE/gen_qa/model/SciFive-base-PMC/arg2"
args.max_answer_length = 30

In [4]:
ARG_TEMPLATES = {
        "Subject": ["Subject", 
        "Subject in [EVENT]",
        "Subject in [TRIGGER]",
        "Who is the subject?",
        "Who is the subject in [EVENT]?",
        "Who is the subject in [TRIGGER]?",
        "Who is treated?",
        "Who is treated in [EVENT]?",
        "Who is treated in [TRIGGER]?"
        ], 
        "Treatment": ["Treatment",
        "Treatment in [EVENT]",
        "Treatment in [TRIGGER]",
        "What is the treatment?",
        "What is the treatment in [EVENT]?",
        "What is the treatment in [TRIGGER]?",
        "What treatment is given to the patient?",
        "What treatment is given to the patient in [EVENT]?",
        "What treatment is given to the patient in [TRIGGER]?",
        ],
        "Effect":["Effect",
        "Effect in [EVENT]",
        "Effect in [TRIGGER]",
        "What is the effect?",
        "What is the effect in [EVENT]?",
        "What is the effect in [TRIGGER]?",
        "What effect does the treatment cause?",
        "What effect does the treatment cause in [EVENT]?",
        "What effect does the treatment cause in [TRIGGER]?"
        ]
    }

In [5]:
# read cases from EEQA prediction result
# compared EEQA question template: arg_query + in trigger(5)
eeqa_result_file = "/home/zha0yuewarwick/projects/PhEE/eeqa/model/biobert_rst/argument/template5/pred_outputs.json"
eeqa_results = []
with open(eeqa_result_file, "r") as f:
    for line in f.readlines():
        eeqa_results.append(json.loads(line))

In [6]:
# load generative QA model
model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path, output_attentions=True)
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

### Start Demo

In [7]:
# sample a case
TEST_ROLE = "Treatment"
random_id = random.randint(0, len(eeqa_results))
select_id = 140
select_id

140

In [8]:
def get_sample_result(select_id, test_role):
    eeqa_rst = eeqa_results[select_id]
    sent_tokens = eeqa_rst['sentence']
    context = " ".join(sent_tokens)
    event_type = eeqa_rst['event'][0][0][-1]

    trig_st = eeqa_rst['event'][0][0][0]
    trigger_text = sent_tokens[eeqa_rst['event'][0][0][0]]
    
    # get gold span
    gold_span = []
    for argument in eeqa_rst['event'][0][1:]: # assume a sentence only includes one event
        if argument[-1] == test_role:
            gold_span.append(" ".join(sent_tokens[argument[0]: argument[1]+1]))
    gold_span = "; ".join(gold_span)

    # get eeqa predicted span
    eeqa_span = []
    for argument in eeqa_rst['arg_pred']:
        if argument[-1] == test_role:
            eeqa_span.append(" ".join(sent_tokens[argument[0]: argument[1]+1]))
    eeqa_span = "; ".join(eeqa_span)
    
    # get gen_qa input and output
    question = ARG_TEMPLATES[test_role][0] # TODO: modify this to match eeqa method, but we currently only train the argument type
    question = question.replace("[EVENT]", " ".join(event_type.split('_')).lower())
    question = question.replace("[TRIGGER]", trigger_text)
    
    input_text = " ".join(["question:", question.lstrip(), "context:", context.lstrip()])
    
    # tokenize the input
    encoder_input_ids = tokenizer(input_text, return_tensors="pt", add_special_tokens=True).input_ids
    # get model outputs
    decoder_input_ids = model.generate(encoder_input_ids, num_beams=args.num_beams, top_p=args.n_best_size)

    encoder_text = tokenizer.convert_ids_to_tokens(encoder_input_ids[0])
    decoder_text = tokenizer.convert_ids_to_tokens(decoder_input_ids[0])
    gen_qa_span = tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)
    
    return {
        'context': context,
        'event_type': event_type,
        'gold_span': gold_span,
        'eeqa_span': eeqa_span,
        'gen_qa_span': gen_qa_span,
        'encoder_text': encoder_text,
        'decoder_text': decoder_text,
        'encoder_input_ids': encoder_input_ids,
        'decoder_input_ids': decoder_input_ids
    }

In [9]:
sample_rst = get_sample_result(select_id, TEST_ROLE)

### Output Comparing Results

In [10]:
# compare outputs
print("CONTEXT: %s"%sample_rst["context"])
print("EVENT TYPE: %s, QUERY_ROLE:%s"%(sample_rst["event_type"], TEST_ROLE))
print("GOLD SPAN: %s"%sample_rst["gold_span"])
print("EEQA PRED: %s"%sample_rst["eeqa_span"])
print("GenQA PRED: %s"%sample_rst["gen_qa_span"])

CONTEXT: 1 . 
 Artelinic acid ( AL ) , a water - soluble artemisinin analogue for treatment of multidrug resistant malaria , is metabolized to the active metabolite dihydroqinghaosu ( DQHS ) solely by   CYP3A4/5 .
EVENT TYPE: Potential_therapeutic_event, QUERY_ROLE:Treatment
GOLD SPAN: Artelinic acid ( AL ) , a water - soluble artemisinin analogue; metabolized to the active metabolite dihydroqinghaosu ( DQHS ) solely by   CYP3A4/5
EEQA PRED: Artelinic acid ( AL )
GenQA PRED: Artelinic acid ( AL ), a water - soluble art


### Visualize Attentions

In [11]:
outputs = model(input_ids=sample_rst["encoder_input_ids"], decoder_input_ids=sample_rst["decoder_input_ids"])

In [None]:
head_view(
    encoder_attention=outputs.encoder_attentions,
    decoder_attention=outputs.decoder_attentions,
    cross_attention=outputs.cross_attentions,
    encoder_tokens= sample_rst["encoder_text"],
    decoder_tokens = sample_rst["decoder_text"]
)

### Batch Analysis

In [None]:
# only sample not exact match cases
TEST_ROLE = "Treatment"
SAMPLE_NUM = 5
case_id = 0
print(TEST_ROLE)
while case_id < SAMPLE_NUM:
    random_id = random.randint(0, len(eeqa_results))
    select_id = random_id

    sample_rst = get_sample_result(select_id, TEST_ROLE)
    if sample_rst["gold_span"] == sample_rst["gen_qa_span"]:
        continue
    
    print("Sample case %d:"%select_id)
    print("CONTEXT: %s"%sample_rst["context"])
    # print("EVENT TYPE: %s, QUERY_ROLE:%s"%(sample_rst["event_type"], TEST_ROLE))
    print("GOLD SPAN: %s"%sample_rst["gold_span"])
    print("EEQA PRED: %s"%sample_rst["eeqa_span"])
    print("GenQA PRED: %s"%sample_rst["gen_qa_span"])
    print("")
    case_id += 1