In [1]:
import spacy
from spacy.tokens import Span
from spacy import displacy
import numpy as np
from transformers import pipeline

In [2]:

import yaml
from collections import defaultdict
from scripts.eval_phee.sel2record_phee import *
from uie.sel2record.sel2record import SEL2Record
from scripts.eval_phee.phee_metric import compute_metric

class SpacyVis:
    """Highlight event triggers and arguments in a sentence using displaCy.

    Event records are turned into phrase patterns for a spaCy ``span_ruler``;
    the matched spans are copied to the ``"sc"`` span group and rendered with
    ``displacy.render(..., style="span")``.
    """

    # Short label prefixes for the event types that have a dedicated
    # abbreviation. Any other type falls back to its own name (spaces
    # replaced by underscores) instead of leaving the prefix undefined.
    _EVENT_PREFIXES = {
        'adverse event': 'ADE',
        'potential therapeutic event': 'PTE',
    }

    def __init__(self):
        """Build a blank English pipeline with a case-insensitive span ruler
        and load the schema dictionary used when decoding records."""
        self.nlp = spacy.blank("en")
        # "LOWER" makes phrase matching case-insensitive.
        self.ruler = self.nlp.add_pipe("span_ruler", config={"phrase_matcher_attr": "LOWER"})
        self.schema_dict = SEL2Record.load_schema_dict("data/converted_data/text2spotasoc/event/phee_2")

    @staticmethod
    def _decode(pred, schema_dict):
        """Parse a bracketed record string into spot/asoc/record lists.

        Args:
            pred: Linearised record string (gold or model output).
            schema_dict: Schema dictionary forwarded to ``get_record_list``.

        Returns:
            dict with keys ``pred_spot``, ``pred_asoc`` and ``pred_record``
            holding the three values returned by ``get_record_list``.

        A string that cannot be parsed (``ValueError``) is replaced by an
        empty bracket pair, so the function itself does not raise for
        ill-formed input.
        """
        left_bracket = '【'
        right_bracket = '】'
        brackets = left_bracket + right_bracket

        # NOTE(review): convert_bracket, clean_text, check_well_form,
        # add_bracket, ParentedTree, get_record_list and logger are not
        # defined in this notebook; presumably they arrive through the
        # star-import of scripts.eval_phee.sel2record_phee -- confirm.
        pred = clean_text(convert_bracket(pred))
        try:
            if not check_well_form(pred):
                pred = add_bracket(pred)
            pred_tree = ParentedTree.fromstring(pred, brackets=brackets)
        except ValueError:
            # Lazy %-formatting: the original call passed ``pred`` without a
            # placeholder, which makes logging report a formatting error.
            logger.debug('ill-formed: %s', pred)
            # Fall back to an empty tree so decoding still yields a result.
            pred_tree = ParentedTree.fromstring(brackets, brackets=brackets)

        spots, asocs, records = get_record_list(
            sel_tree=pred_tree,
            schema_dict=schema_dict
        )
        return {
            'pred_spot': spots,
            'pred_asoc': asocs,
            'pred_record': records,
        }

    def convert_to_record(self, output):
        """Decode a linearised record string and return its events as a list.

        Args:
            output: Linearised record string passed to ``_decode``.

        Returns:
            A new list with the items of the ``"event"`` entry produced by
            ``proprocessing_graph_record``.
        """
        pred_instance = self._decode(output, self.schema_dict)
        pred_record = proprocessing_graph_record(
            pred_instance,
            self.schema_dict
        )
        return list(pred_record["event"])

    def construct_labels(self, sentence, pred):
        """Annotate ``sentence`` with the spans described by ``pred``.

        Args:
            sentence: Raw sentence text.
            pred: Iterable of event dicts. Each has a ``'type'``, an optional
                ``'trigger'`` string and ``'roles'``, an iterable of
                ``(argument_type, span_text)`` pairs.

        Returns:
            The processed doc, with the ruler matches copied into
            ``doc.spans["sc"]``. Labels look like ``ADE.trigger`` or
            ``PTE.treatment_drug``.
        """
        patterns = []
        for event in pred:
            event_type = event['type']
            prefix = self._EVENT_PREFIXES.get(event_type)
            if prefix is None:
                # Unknown event type: previously the prefix was either
                # undefined (first event) or silently inherited from the
                # preceding event. Use the type name itself instead.
                prefix = event_type.replace(" ", "_")

            if "trigger" in event:
                patterns.append({"label": prefix + ".trigger", "pattern": event["trigger"]})

            for arg_type, arg_span in event['roles']:
                # Spaces in role names become underscores in the label.
                label = prefix + "." + arg_type.replace(" ", "_")
                patterns.append({"label": label, "pattern": arg_span})

        # Drop the patterns of any previous call so labels do not leak from
        # one sentence to the next.
        self.ruler.clear()
        self.ruler.add_patterns(patterns)
        sent_doc = self.nlp(sentence)

        # Copy the matches into the "sc" group, which is the span group that
        # displaCy's span style reads by default.
        sent_doc.spans["sc"] = [
            Span(sent_doc, match.start, match.end, match.label_)
            for match in sent_doc.spans["ruler"]
        ]

        return sent_doc

    def visualise_spans(self, sentence, output):
        """Render the events encoded in the linearised string ``output``
        as highlighted spans over ``sentence``."""
        record = self.convert_to_record(output)
        docs = self.construct_labels(sentence, record)
        displacy.render(docs, style="span")

    def visualise_gpt_spans(self, sentence, gpt_ouput):
        """Render already-structured events (a list of event dicts, e.g. from
        ``convert_gpt_result``) as highlighted spans over ``sentence``."""
        docs = self.construct_labels(sentence, gpt_ouput)
        displacy.render(docs, style="span")

In [3]:
# Single shared visualiser instance reused by the cells below. The spelling
# of the name is kept as-is because later cells refer to it.
visulaizer = SpacyVis()

In [4]:
def convert_gpt_result(json_file):
    """Convert one saved ChatGPT response file into a list of event dicts.

    The file is a JSON object whose ``"answer"`` field is itself a JSON
    string encoding a list of events, each with an ``"event_type"`` and an
    ``"arguments"`` mapping of argument type to span text.

    Args:
        json_file: Path to the response file.

    Returns:
        A list of ``{"type": <event type>, "roles": [(arg_type, span), ...]}``
        dicts. Argument values containing a "no value" marker are dropped,
        ``;``-separated values are split into one role per piece, and the
        argument type ``indication`` is renamed to ``treatment.disorder``.
        Non-string values (for example JSON ``null``) and empty pieces are
        skipped.
    """
    # Local import so the function does not depend on a later cell having
    # imported json already.
    import json

    # Substrings that mark an argument as "no value given".
    empty_markers = ("n/a", "null", "none", "not mentioned")

    def _filter_span(span):
        """Return True when ``span`` should be discarded."""
        if not isinstance(span, str):
            # e.g. JSON null -> None; the original crashed on .lower() here.
            return True
        lowered = span.lower()
        return any(marker in lowered for marker in empty_markers)

    with open(json_file, 'r') as f:
        result = json.load(f)
    # The answer is stored as a JSON-encoded string inside the outer object.
    result = json.loads(result['answer'])

    output = []
    for event in result:
        evt = {'type': event["event_type"], "roles": []}
        for arg_type, arg_span in event["arguments"].items():
            if arg_type == 'indication':
                arg_type = 'treatment.disorder'
            if _filter_span(arg_span):
                continue
            for piece in arg_span.split(";"):
                piece = piece.strip()
                # A trailing or doubled ';' yields an empty piece, which
                # would otherwise become an empty span.
                if piece:
                    evt["roles"].append((arg_type, piece))
        output.append(evt)

    return output

In [5]:
# Input locations used by the sampling cell below.
# gold_file: read line by line as JSON objects with "text_id", "text" and "record" fields.
gold_file = "dataset_processing/data/converted_data/text2spotasoc/event/phee2_cross1/test.json"
# gpt_folder: directory holding one "<text_id>.json" ChatGPT response per sentence.
gpt_folder = "chatgpt_few_shot/output/BM25-Type_demo_5/test"
# syn_file: read the same way as gold_file; holds the synthesized sentences.
syn_file = "dataset_processing/data/converted_data/text2spotasoc/event/phee2_cross3/train_aug.json"

In [17]:
# os is imported explicitly: os.path.join is used below and no other visible
# cell imports the module.
import os
import random
import json

# Gold test set: one JSON object per line, indexed by text_id.
with open(gold_file, 'r') as f:
    gold_instances = {}
    for line in f:
        d = json.loads(line)
        gold_instances[d["text_id"]] = (d["text"], d["record"])

# Synthesized sentences, keyed by the first two underscore-separated fields
# of their text_id so they can be looked up with a gold id. When several
# entries share that prefix, the last one read wins.
with open(syn_file, 'r') as f:
    syn_instances = {}
    for line in f:
        d = json.loads(line)
        syn_key = "_".join(d["text_id"].split("_")[:2])
        syn_instances[syn_key] = (d["text"], d["record"])

# random.sample() rejects non-sequence populations such as dict key views on
# Python 3.11+, so draw from an explicit list instead.
sample_id = random.choice(list(gold_instances))
# sample_id = '3824704_1'

sentence, gold_record = gold_instances[sample_id]
gpt_output = convert_gpt_result(os.path.join(gpt_folder, sample_id + ".json"))
syn_sentence, syn_record = syn_instances[sample_id]


print(sample_id)
print("Annotation:")
visulaizer.visualise_spans(sentence, gold_record)

print("ChatGPT Prediction:")
visulaizer.visualise_gpt_spans(sentence, gpt_output)

print("Synthesized sentence:")
print(syn_sentence)
visulaizer.visualise_spans(syn_sentence, syn_record)


12659609_2
Annotation:


ChatGPT Prediction:


Synthesized sentence:
A case of severe anisocoria was reported due to the use of transdermal scopolamine .


In [7]:
# error analysis
# Collect the ids of sentences where the first model's spans for one argument
# type match the gold spans exactly while the second model's do not.
gold_file = "dataset_processing/data/converted_data/text2spotasoc/event/phee2_cross1/test.json"
pred1_file = "hf_models/cross_flan_t5_large_instruction_finetune_phee2_spot_asoc_noise_0_order/cross1/test_preds_seq2seq.txt"
pred2_file = "hf_models/cross_flan_t5_large_instruction_finetune_phee2aug_gpt_spot_asoc_noise_0/cross1/test_preds_seq2seq.txt"


def _role_spans(events, role):
    """Return, in order, the span of every (arg_type, span) pair in the
    events' 'roles' whose arg_type equals ``role``."""
    return [span for evt in events for arg, span in evt['roles'] if arg == role]


text_ids = []
with open(gold_file, 'r') as f:
    gold_instances = {}
    for line in f:
        d = json.loads(line)
        text_ids.append(d["text_id"])
        gold_instances[d["text_id"]] = (d["text"], d["record"])

# Prediction files are paired with the gold ids by line position.
with open(pred1_file) as f:
    pred1_records = f.readlines()

with open(pred2_file) as f:
    pred2_records = f.readlines()

# zip() below stops at the shortest input, so a length mismatch would silently
# drop sentences; make it visible.
if not (len(text_ids) == len(pred1_records) == len(pred2_records)):
    print("WARNING: gold/pred1/pred2 sizes differ: %d/%d/%d" % (
        len(text_ids), len(pred1_records), len(pred2_records)))

# NOTE(review): other cells in this notebook use dotted role names such as
# 'treatment.disorder'; confirm that 'race' is the exact role name produced by
# convert_to_record, otherwise every list below is empty and nothing is kept.
EVAL_ARG = 'race'

vis_id = []
for sid, pred1, pred2 in zip(text_ids, pred1_records, pred2_records):
    text, gold = gold_instances[sid]
    gold_args = _role_spans(visulaizer.convert_to_record(gold), EVAL_ARG)
    p1_args = _role_spans(visulaizer.convert_to_record(pred1), EVAL_ARG)
    p2_args = _role_spans(visulaizer.convert_to_record(pred2), EVAL_ARG)

    # Order-sensitive list comparison, as in the original analysis.
    if gold_args == p1_args and gold_args != p2_args:
        vis_id.append(sid)

print(len(vis_id))

0


In [113]:
# Which of the collected cases (vis_id) to inspect.
VIS_INDEX = 3

if VIS_INDEX >= len(vis_id):
    # vis_id can be shorter than expected (even empty); report that instead of
    # failing with an IndexError.
    print("vis_id has only %d case(s); nothing to show at index %d" % (len(vis_id), VIS_INDEX))
else:
    sid = vis_id[VIS_INDEX]
    case_text, case_gold = gold_instances[sid]
    # Predictions are aligned with text_ids by position; look the row up once.
    case_row = text_ids.index(sid)

    print(sid)
    print(case_text)
    print("Annotation:")
    visulaizer.visualise_spans(case_text, case_gold)

    print("Finetune Prediction:")
    visulaizer.visualise_spans(case_text, pred1_records[case_row])

    print("Aug Prediction:")
    visulaizer.visualise_spans(case_text, pred2_records[case_row])

21630612_4
We report a case of severe simvastatin - induced rhabdomyolysis triggered by the addition of amiodarone to previously well - tolerated chronic statin therapy .
Annotation:


Finetune Prediction:


Aug Prediction:
