In [1]:
from collections import defaultdict
from datasets import load_dataset
from enum import Enum
from datetime import datetime, timedelta
from tqdm.auto import tqdm

import evaluate
import json
import numpy as np
import os

In [2]:
# Folder containing `test_pico_spans.json`, the file loaded further down for evaluation.
INPUT_FOLDER = 'data/bioc/json/step_2_span_clf'

class DatasetSplit(Enum):
    """Enumeration of the dataset split names: train, validation and test."""
    train = 0
    validation = 1
    test = 2
    
class PicoType(Enum):
    """PICO element categories evaluated in this notebook.

    Each member's value is a distinct single-bit mask (4, 2, 1).
    `extract_spans_from_labels` ANDs that value with every token label to
    decide whether the token belongs to the given category, and the member
    name is used as the key into each sample's `pico_elements` dict.
    """
    PARTICIPANTS = 4
    INTERVENTIONS = 2
    OUTCOMES = 1

In [3]:
class Span:
    """A contiguous span described by its start index and its length.

    Attributes:
        start: index of the first position covered by the span.
        length: number of positions covered. May be a float (label-derived
            spans count with floats); comparisons truncate it with `int()`.
        end: `start + length`, computed once at construction time. It is not
            refreshed if `start` or `length` are reassigned afterwards.
    """

    def __init__(self, start, length):
        self.start = start
        self.length = length
        self.end = self.start + self.length

    def __repr__(self):
        return f'Span(start={self.start}, length={self.length})'

    def _key(self):
        """Return the `(start, int(length))` tuple used for equality and hashing."""
        return (self.start, int(self.length))

    def __eq__(self, other):
        """Two spans are equal when their start and integer length match.

        Objects that do not expose usable `start` / `length` attributes are
        reported as "not comparable" (NotImplemented) instead of raising, so
        `span == "foo"` evaluates to False rather than crashing.
        """
        try:
            other_key = (other.start, int(other.length))
        except (AttributeError, TypeError, ValueError):
            return NotImplemented
        return self._key() == other_key

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None; restore hashing with the
        # same key that __eq__ compares so equal spans hash equally.
        return hash(self._key())

In [4]:
def extract_spans_from_labels(label_sequence, pico_type_value):
    """Turn a per-position label sequence into a list of `Span` objects.

    Every positive label is AND-ed with `pico_type_value`; positions where the
    result is positive are considered "on". Each maximal run of consecutive
    "on" positions becomes one span.

    Args:
        label_sequence: sequence of numeric labels, one per position.
        pico_type_value: integer bit mask selecting the category of interest.

    Returns:
        Spans in order of increasing start index. Span lengths are floats
        because the run length is accumulated starting from 0.0.
    """
    masked = [pico_type_value & lab if lab > 0 else 0 for lab in label_sequence]

    spans = []
    run_start = None   # start index of the run currently being read, if any
    run_length = 0.0   # kept as a float on purpose (matches the stored span lengths)

    for idx, flag in enumerate(masked):
        if flag > 0:
            if run_start is None:
                # First "on" position after an "off" one (or at index 0).
                run_start = idx
                run_length = 0.0
            run_length += 1
        elif run_start is not None:
            # The run ended on the previous position: emit it.
            spans.append(Span(start=run_start, length=run_length))
            run_start = None

    # A run that reaches the end of the sequence is still open here.
    if run_start is not None:
        spans.append(Span(start=run_start, length=run_length))

    return spans

In [5]:
def eval_pred_single_sample(prediction, reference):
    """Count exact-match true/false positives and false negatives for one sample.

    Both span collections are ordered by `start` and walked in lockstep. A
    predicted span counts as a true positive only when it compares equal to
    the current reference span; otherwise whichever span starts first is
    consumed and counted as a false positive (prediction) or false negative
    (reference).

    Args:
        prediction: iterable of predicted spans (objects with a `start`
            attribute that support `==`).
        reference: iterable of gold spans of the same kind.

    Returns:
        Tuple `(tp, fp, fn)` of floats.

    Note:
        The inputs are not modified: sorted copies are used, so the caller's
        lists keep their original order.
    """
    # sorted() returns new lists; list.sort() would reorder the caller's data.
    preds = sorted(prediction, key=lambda span: span.start)
    refs = sorted(reference, key=lambda span: span.start)

    tp, fp, fn = 0.0, 0.0, 0.0
    pi, ri = 0, 0
    while pi < len(preds) and ri < len(refs):
        span_pred, span_ref = preds[pi], refs[ri]
        if span_pred == span_ref:
            pi += 1
            ri += 1
            tp += 1
        elif span_pred.start < span_ref.start:
            # Prediction lies before the current reference: nothing can match it.
            pi += 1
            fp += 1
        else:
            # Reference starts at or before the prediction without matching it.
            ri += 1
            fn += 1

    # Whatever is left on either side after the walk is unmatched.
    fp += len(preds) - pi
    fn += len(refs) - ri

    return tp, fp, fn

In [6]:
class ConfusionMatrix:
    """Accumulator for true-positive / false-positive / false-negative counts.

    Instances can be added together (`a + b`, `a += b`) to pool counts; the
    result is always a new `ConfusionMatrix`.

    Attributes:
        tp, fp, fn: the raw counts.
        precision, recall, f1: scores derived from the counts. They are
            computed at construction time and refreshed by `compute()` and
            `__repr__`; call `compute()` after changing the counts by hand.
    """

    def __init__(self, tp=0, fp=0, fn=0):
        self.tp = tp
        self.fp = fp
        self.fn = fn
        self.precision, self.recall, self.f1 = 0, 0, 0
        # Derive the scores right away so the attributes are never stale.
        self.compute()

    def __add__(self, other):
        """Return a new matrix holding the element-wise sum of the counts."""
        try:
            tp = self.tp + other.tp
            fp = self.fp + other.fp
            fn = self.fn + other.fn
        except AttributeError:
            # `other` has no counts: let Python raise its usual TypeError.
            return NotImplemented
        return ConfusionMatrix(tp, fp, fn)

    def compute(self):
        """Recompute precision, recall and F1 from the current counts.

        All three scores are 0 when there are no true positives, which also
        avoids dividing by zero.

        Returns:
            Tuple `(precision, recall, f1)`; the same values are stored on
            the instance.
        """
        if self.tp:
            precision = self.tp / (self.tp + self.fp)
            recall = self.tp / (self.tp + self.fn)
            f1 = 2 * precision * recall / (precision + recall)
        else:
            precision, recall, f1 = 0, 0, 0
        self.precision, self.recall, self.f1 = precision, recall, f1
        return precision, recall, f1

    def __repr__(self):
        # Refresh first so the report reflects any manual change to the counts.
        self.compute()

        return (
            f' tp: {self.tp}\n'
            f' fp: {self.fp}\n'
            f' fn: {self.fn}\n'
            f' precision: {self.precision*100:.2f}\n'
            f' recall: {self.recall*100:.2f}\n'
            f' f1: {self.f1*100:.2f}\n\n'
        )

In [7]:
# Only the 'test' split is evaluated. 'train' and 'validation' are not used;
# they point at the same file so that all three split names exist.
ebm_nlp = load_dataset(
    'json',
    data_files=dict.fromkeys(
        ('train', 'validation', 'test'),
        os.path.join(INPUT_FOLDER, 'test_pico_spans.json'),
    ),
)

Found cached dataset json (/home/gzhang/.cache/huggingface/datasets/json/default-deaedfaba8add16c/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)


  0%|          | 0/3 [00:00<?, ?it/s]

In [8]:
# Despite its name, `val` holds the 'test' split; it is the split scored in the evaluation loop.
val = ebm_nlp['test']

In [9]:
# Minimum confidence a predicted span needs in order to be scored.
CONFIDENCE_THRESHOLD = 0.5

# One accumulator per PICO category, plus one pooled over all categories.
participants_metric = ConfusionMatrix()
interventions_metric = ConfusionMatrix()
outcomes_metric = ConfusionMatrix()
all_metric = ConfusionMatrix()

# Iterate row by row. Indexing a column first (val['original_labels'][i])
# makes `datasets` build the whole column on every iteration, which turns the
# loop quadratic in the number of samples.
for sample in tqdm(val):
    original_labels = sample['original_labels']
    result_dict = sample['pico_elements']
    for pico_type in PicoType:
        records = result_dict[pico_type.name]
        # Predicted spans: keep only those above the confidence threshold.
        # `records` may be empty/None when nothing was predicted for this type.
        prediction = [
            Span(start=r['span_start'], length=r['span_length'])
            for r in records
            if r['confidence'] > CONFIDENCE_THRESHOLD
        ] if records else []
        # Gold spans: rebuilt from the token-level labels for this category.
        reference = extract_spans_from_labels(original_labels, pico_type.value)
        tp, fp, fn = eval_pred_single_sample(prediction, reference)
        batch_result = ConfusionMatrix(tp, fp, fn)
        all_metric += batch_result
        if pico_type == PicoType.PARTICIPANTS:
            participants_metric += batch_result
        elif pico_type == PicoType.INTERVENTIONS:
            interventions_metric += batch_result
        elif pico_type == PicoType.OUTCOMES:
            outcomes_metric += batch_result

  0%|          | 0/2042 [00:00<?, ?it/s]

In [10]:
# Scores pooled over all three PICO types. Reference numbers for comparison:
# baseline:         p .497, r .412, f1 .450
all_metric

 tp: 2038.0
 fp: 1772.0
 fn: 2164.0
 precsion: 53.49
 recall: 48.50
 f1: 50.87


In [9]:
# Folder containing `test_baseline_pred.json`, the baseline predictions loaded in the next cell.
BASELINE_FOLDER = 'baseline'

In [10]:
# Only the 'test' split is evaluated; the other two split names reuse the
# same file so that all three exist.
baseline = load_dataset(
    'json',
    data_files=dict.fromkeys(
        ('train', 'validation', 'test'),
        os.path.join(BASELINE_FOLDER, 'test_baseline_pred.json'),
    ),
)

Found cached dataset json (/home/gzhang/.cache/huggingface/datasets/json/default-3e98f4a6402f7410/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)


  0%|          | 0/3 [00:00<?, ?it/s]

In [11]:
# Despite its name, `baseline_val` holds the 'test' split of the baseline predictions.
baseline_val = baseline['test']

In [14]:
# One accumulator per PICO category, plus one pooled over all categories.
baseline_participants_metric = ConfusionMatrix()
baseline_interventions_metric = ConfusionMatrix()
baseline_outcomes_metric = ConfusionMatrix()
baseline_all_metric = ConfusionMatrix()

# Iterate over the baseline dataset itself, row by row, so this cell does not
# depend on `val` or any other variable left behind by the earlier
# evaluation cell. Row iteration also avoids rebuilding a full column on
# every step.
for sample in tqdm(baseline_val):
    reference = sample['original_labels']
    prediction = sample['pico_pred']
    for pico_type in PicoType:
        # Both sides are token-level label sequences; convert each to spans.
        pred = extract_spans_from_labels(prediction, pico_type.value)
        ref = extract_spans_from_labels(reference, pico_type.value)
        tp, fp, fn = eval_pred_single_sample(pred, ref)
        batch_result = ConfusionMatrix(tp, fp, fn)
        baseline_all_metric += batch_result
        if pico_type == PicoType.PARTICIPANTS:
            baseline_participants_metric += batch_result
        elif pico_type == PicoType.INTERVENTIONS:
            baseline_interventions_metric += batch_result
        elif pico_type == PicoType.OUTCOMES:
            baseline_outcomes_metric += batch_result

  0%|          | 0/2042 [00:00<?, ?it/s]

In [15]:
# Baseline scores pooled over all three PICO types (rendered by ConfusionMatrix.__repr__).
baseline_all_metric

 tp: 1731.0
 fp: 1752.0
 fn: 2471.0
 precsion: 49.70
 recall: 41.19
 f1: 45.05
