In [1]:
from collections import defaultdict
from datasets import load_dataset, Sequence, ClassLabel
from enum import Enum
from datetime import datetime, timedelta
from pytz import timezone
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification

import evaluate
import json
import numpy as np
import os
import torch

In [2]:
# Tokenizer checkpoint (the span classifier below is loaded from its own path).
BASE_MODEL = 'microsoft/BiomedNLP-PubMedBERT-large-uncased-abstract'
# Fine-tuned span classifier; its outputs are passed through a sigmoid and thresholded
# per class, i.e. used as a multi-label classifier.
SPAN_CLF_MODEL_PATH = 'pico_span/span_clf/span-clf-PICO_NER-ebm_nlp_bioc-2023_06_02_05_59_04_EDT/checkpoint-3020'
# SPAN_CLF_MODEL_PATH = 'pico_span/span_clf/span-clf-PICO_NER-ebm_nlp_bioc-entity_only'
# Step-1 output folder; 'test_boundary_pred.json' inside it is read for every split.
INPUT_FOLDER = 'data/bioc/json/step_1_boundary_pred'
# Output DIRECTORY (not a file); '<split>_pico_spans.json' is written inside it.
OUTPUT_PATH = 'data/bioc/json/step_2_span_clf'

# Split selector: the member's .name is used as the DatasetDict key and as the
# output file-name prefix.
class DatasetSplit(Enum):
    train = 0
    validation = 1
    test = 2
    
# NOTE(review): not referenced by the code in this notebook. The values appear to
# mirror the integer codes found in 'original_labels' (in the sample output, 2 sits
# on the drug names and 4 on the condition) — confirm against the data-prep step.
class PicoType(Enum):
    PARTICIPANTS = 4
    INTERVENTIONS = 2
    OUTCOMES = 1

In [3]:
# Class names in index order, plus the index <-> name lookups built from them.
# Prediction below reads model.config.id2label instead of these tables.
# NOTE(review): confirm this order matches the checkpoint's config.
PICO_CLASSES = [
    'PARTICIPANTS',
    'INTERVENTIONS',
    'OUTCOMES',
]

id2label = dict(enumerate(PICO_CLASSES))
label2id = {name: idx for idx, name in id2label.items()}

In [4]:
# Every split is backed by the same step-1 test-set predictions file; the train
# split is not used, and validation is only used for the quick check further down.
ebm_nlp = load_dataset(
    'json',
    data_files={
        split_name: os.path.join(INPUT_FOLDER, 'test_boundary_pred.json')
        for split_name in ('train', 'validation', 'test')
    },
)

Found cached dataset json (/home/gzhang/.cache/huggingface/datasets/json/default-59f01fb16cb7545b/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)


  0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
# Show split sizes and column names of the loaded DatasetDict.
ebm_nlp

DatasetDict({
    train: Dataset({
        features: ['pmid', 'tokens', 'original_labels', 'boundary_pred', 'start_confidence', 'end_confidence'],
        num_rows: 2042
    })
    validation: Dataset({
        features: ['pmid', 'tokens', 'original_labels', 'boundary_pred', 'start_confidence', 'end_confidence'],
        num_rows: 2042
    })
    test: Dataset({
        features: ['pmid', 'tokens', 'original_labels', 'boundary_pred', 'start_confidence', 'end_confidence'],
        num_rows: 2042
    })
})

In [6]:
# Classifier weights come from the fine-tuned checkpoint; the tokenizer comes from
# the base model. NOTE(review): confirm the checkpoint was trained with this vocabulary.
model = AutoModelForSequenceClassification.from_pretrained(SPAN_CLF_MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

In [7]:
def extract_span_boundaries(boundary_pred):
    """Enumerate candidate (start, end) token spans from per-token boundary flags.

    Each entry of ``boundary_pred`` is read as a bit mask: bit value 1 marks a
    possible span start, bit value 2 a possible span end (3 marks both).

    Returns:
        list of ``(start, end)`` tuples with ``start <= end`` (end inclusive),
        ordered by start index and then by end index. Its size is
        O(#starts * #ends), and each candidate is later classified separately.
    """
    start_positions = [pos for pos, flag in enumerate(boundary_pred) if flag & 1]
    end_positions = [pos for pos, flag in enumerate(boundary_pred) if flag & 2]
    return [
        (first, last)
        for first in start_positions
        for last in end_positions
        if first <= last
    ]

In [8]:
class Span:
    """A contiguous run of tokens: first token index ``start`` and ``length`` in tokens."""

    def __init__(self, start, length):
        self.start = start
        self.length = length

    def __repr__(self):
        return f'Span(start={self.start}, length={self.length})'

    def __eq__(self, other):
        # Lengths are compared as ints, so Span(1, 2.0) == Span(1, 2).
        # Objects without start/length are not comparable: return NotImplemented so
        # Python falls back to its default (False) instead of raising AttributeError.
        try:
            return (self.start, int(self.length)) == (other.start, int(other.length))
        except AttributeError:
            return NotImplemented

    def __hash__(self):
        # Defining __eq__ alone makes instances unhashable; hash the same key that
        # __eq__ compares so equal spans hash equally (usable in sets/dict keys).
        return hash((self.start, int(self.length)))

In [9]:
def text_span_iou(a, b):
    """Intersection-over-union of two token spans.

    Spans are objects with ``start`` and ``length`` attributes covering the
    half-open token range ``[start, start + length)``.

    Returns:
        float IoU; 0.0 when the spans do not overlap (adjacent spans included)
        or when the union has zero width.
    """
    a_start, a_end = a.start, a.start + a.length
    b_start, b_end = b.start, b.start + b.length
    # For overlapping spans the hull equals the union; for disjoint spans the
    # intersection is 0, so using the hull there does not change the result.
    union = max(a_end, b_end) - min(a_start, b_start)
    if union == 0:
        return 0.0
    intersection = max(0, min(a_end, b_end) - max(a_start, b_start))
    return intersection / union


def nms_span(spans, span_class, confidence, iou_threshold=0.0):
    """Greedy suppression of overlapping spans that share a class.

    Two spans of the same class conflict when their IoU is strictly greater
    than ``iou_threshold``. The LONGER span of a conflicting pair is dropped;
    on equal length the earlier one (lower list index) is kept. ``confidence``
    does not influence the decision; it is only filtered in step with ``spans``.

    Args:
        spans: span objects with ``start``/``length``.
        span_class: class label per span (same order as ``spans``).
        confidence: score per span (same order as ``spans``).
        iou_threshold: overlap above which two same-class spans conflict.

    Returns:
        ``(spans, span_class, confidence)`` restricted to the kept entries,
        in their original order.
    """
    keep = [True] * len(spans)
    for i in range(len(spans) - 1):
        if not keep[i]:
            continue
        for j in range(i + 1, len(spans)):
            if not keep[j] or span_class[i] != span_class[j]:
                continue
            if text_span_iou(spans[i], spans[j]) > iou_threshold:
                # Prefer the shorter span. (Confidence-based alternative that was
                # tried: drop i when confidence[i] < confidence[j].)
                if spans[i].length > spans[j].length:
                    keep[i] = False
                    # A suppressed span must not go on to suppress later spans
                    # that only overlap it (and not the span that removed it).
                    break
                keep[j] = False

    return (
        [s for s, k in zip(spans, keep) if k],
        [c for c, k in zip(span_class, keep) if k],
        [c for c, k in zip(confidence, keep) if k],
    )

In [10]:
def extract_pico_elements_util(tokens, boundary_pred, tokenizer, model, threshold=0.5):
    """Classify every candidate span of one abstract, then suppress overlaps.

    Candidates come from ``extract_span_boundaries(boundary_pred)``. Each
    candidate's tokens are classified on their own (one forward pass per
    candidate). Every class whose sigmoid probability is ``>= threshold``
    yields one (span, class, probability) entry, so a single span can be
    emitted under several classes. The entries are then filtered by ``nms_span``.

    Args:
        tokens: word tokens of the abstract.
        boundary_pred: per-token boundary flags aligned with ``tokens``.
        tokenizer: tokenizer accepting pre-split words (``is_split_into_words``).
        model: sequence-classification model; ``model.config.id2label`` maps an
            output column index to a class name.
        threshold: minimum per-class probability for an entry to be kept.

    Returns:
        ``(spans, pico_class, confidence)`` as returned by ``nms_span``.
    """
    spans = []
    pico_class = []
    confidence = []
    for start, end in extract_span_boundaries(boundary_pred):
        content = tokens[start: end + 1]
        # truncation=True: a very long candidate would otherwise exceed the
        # tokenizer's model_max_length and crash the forward pass; it is now
        # classified on its leading subwords instead.
        encoded = tokenizer(content, truncation=True, return_tensors='pt', is_split_into_words=True)
        # Inference only: skip autograd bookkeeping for every candidate.
        with torch.no_grad():
            logits = model(**encoded).logits
        # Batch of one -> take row 0. Unlike np.squeeze this stays a list even
        # for a single-label model. Assumes logits is (1, num_labels).
        probabilities = torch.sigmoid(logits)[0].tolist()
        for label_idx, prob in enumerate(probabilities):
            if prob < threshold:
                continue
            pico_class.append(model.config.id2label[label_idx])
            confidence.append(prob)
            spans.append(Span(start=start, length=len(content)))
    return nms_span(spans, pico_class, confidence)

In [11]:
# Quick sanity check: classify the candidate spans of the first validation abstract.
val = ebm_nlp['validation']
tokens = val[0]['tokens']
boundary_pred = val[0]['boundary_pred']
spans, pico_class, confidence = extract_pico_elements_util(
    tokens, boundary_pred, tokenizer, model
)

print(val[0]['original_labels'])  # reference labels, to compare by eye
print(boundary_pred)
for s, c, conf in zip(spans, pico_class, confidence):
    print(s, c, conf)

[0, 0, 2, 2, 0, 2, 2, 0, 0, 0, 0, 4, 4, 4, 4]
[0, 0, 1, 2, 0, 1, 2, 0, 0, 0, 0, 1, 0, 0, 2]
Span(start=2, length=2) INTERVENTIONS 0.9953840374946594
Span(start=5, length=2) INTERVENTIONS 0.9956023693084717
Span(start=11, length=4) PARTICIPANTS 0.8516499996185303


In [12]:
# Show the word tokens of the first validation example (set in the previous cell).
tokens

['Comparison',
 'of',
 'budesonide',
 'Turbuhaler',
 'with',
 'budesonide',
 'aqua',
 'in',
 'the',
 'treatment',
 'of',
 'seasonal',
 'allergic',
 'rhinitis',
 '.']

In [13]:
def extract_pico_elements(dataset_dict, dataset_split, output_path, model, tokenizer, threshold=0.5):
    """Run span classification over one split and write the results as JSON Lines.

    Writes ``<output_path>/<split name>_pico_spans.json`` (one JSON object per
    line, despite the .json suffix). Each line holds the input row's ``pmid``,
    ``tokens``, ``original_labels`` and ``boundary_pred``, plus
    ``pico_elements``: a mapping from class name to a list of span dicts with
    ``span_start``, ``span_length``, ``confidence`` (classifier probability),
    ``start_confidence`` and ``end_confidence`` (step-1 scores at the span's
    first and last token).

    For backward compatibility the row also carries top-level
    ``start_confidence``/``end_confidence`` keys. These hold the values of the
    LAST span written and are absent when no span was kept; prefer the
    per-span fields.

    Args:
        dataset_dict: DatasetDict indexed by split name.
        dataset_split: DatasetSplit member; its ``.name`` selects the split.
        output_path: output directory, created if missing.
        model, tokenizer, threshold: forwarded to ``extract_pico_elements_util``.
    """
    output_file = os.path.join(output_path, '{}_pico_spans.json'.format(dataset_split.name))
    if output_path:
        os.makedirs(output_path, exist_ok=True)

    dataset = dataset_dict[dataset_split.name]
    with open(output_file, 'w') as fout:
        # Iterate rows directly: dataset['col'][i] would materialise the whole
        # column on every access, making the loop quadratic in the split size.
        for example in tqdm(dataset):
            row = {
                'pmid': example['pmid'],
                'tokens': example['tokens'],
                'original_labels': example['original_labels'],
                'boundary_pred': example['boundary_pred'],
            }
            start_confidence = example['start_confidence']
            end_confidence = example['end_confidence']
            spans, pico_class, confidence = extract_pico_elements_util(
                row['tokens'], row['boundary_pred'],
                tokenizer, model,
                threshold=threshold)
            row['pico_elements'] = {}
            for span, label, conf in zip(spans, pico_class, confidence):
                span_start_conf = start_confidence[span.start]
                span_end_conf = end_confidence[span.start + span.length - 1]
                row['pico_elements'].setdefault(label, []).append({
                    'span_start': span.start,
                    'span_length': span.length,
                    'confidence': conf,
                    'start_confidence': span_start_conf,
                    'end_confidence': span_end_conf,
                })
                # Legacy row-level keys (overwritten per span; kept for old readers).
                row['start_confidence'] = span_start_conf
                row['end_confidence'] = span_end_conf
            fout.write('{}\n'.format(json.dumps(row)))

In [14]:
# Classify all test-set candidate spans and write <OUTPUT_PATH>/test_pico_spans.json.
extract_pico_elements(
    dataset_dict=ebm_nlp,
    dataset_split=DatasetSplit.test,
    output_path=OUTPUT_PATH,
    model=model,
    tokenizer=tokenizer,
    threshold=0.5,
)

  0%|          | 0/2042 [00:00<?, ?it/s]

In [18]:
# OUTPUT_PATH_t = 'data/bioc/json/step_2_span_clf/threshold'

# def extract_pico_elements_exp(b_threshold, output_path, model, tokenizer, threshold=0.5):
#     output_file = os.path.join(output_path, f'{b_threshold:.2f}_pico_spans.json')
#     if not os.path.exists(output_path):
#         os.makedirs(output_path)
        
#     input_folder = 'data/bioc/json/step_1_boundary_pred/threshold'
#     dataset_dict = load_dataset(
#         'json',
#         data_files = {
#             'train': os.path.join(input_folder, f'{b_threshold:.2f}_boundary_pred.json'), # not used
#             'validation': os.path.join(input_folder, f'{b_threshold:.2f}_boundary_pred.json'),
#             'test': os.path.join(input_folder, f'{b_threshold:.2f}_boundary_pred.json')
#         }
#     )
        
#     dataset = dataset_dict['test']
#     progress_bar = tqdm(range(len(dataset)))
#     with open(output_file, 'w+') as fout:
#         for i in range(len(dataset)):
#             row = {}
#             row['pmid'] = dataset['pmid'][i]
#             row['tokens'] = dataset['tokens'][i]
#             row['original_labels'] = dataset['original_labels'][i]
#             row['boundary_pred'] = dataset['boundary_pred'][i]
#             start_confidence = dataset['start_confidence'][i]
#             end_confidence = dataset['end_confidence'][i]
#             spans, pico_class, confidence = extract_pico_elements_util(
#                 row['tokens'], row['boundary_pred'],
#                 tokenizer, model,
#                 threshold=threshold)
#             row['pico_elements'] = {}
#             for s, c, conf in zip(spans, pico_class, confidence):
#                 if c not in row['pico_elements']:
#                     row['pico_elements'][c] = []
#                 span_dict = {}
#                 span_dict['span_start'] = s.start
#                 span_dict['span_length'] = s.length
#                 span_dict['confidence'] = conf
#                 row['pico_elements'][c].append(span_dict)
#                 row['start_confidence'] = start_confidence[s.start]
#                 row['end_confidence'] = end_confidence[s.start + s.length - 1]
#             fout.write('{}\n'.format(json.dumps(row)))
#             progress_bar.update(1)

In [None]:
# extract_pico_elements_exp(0.2, OUTPUT_PATH_t, model, tokenizer, threshold=0.5)

Downloading and preparing dataset json/default to /home/gzhang/.cache/huggingface/datasets/json/default-4501c2141bd1f5c2/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4...


Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Dataset json downloaded and prepared to /home/gzhang/.cache/huggingface/datasets/json/default-4501c2141bd1f5c2/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/2042 [00:00<?, ?it/s]

In [None]:
# extract_pico_elements_exp(0.25, OUTPUT_PATH_t, model, tokenizer, threshold=0.5)

In [None]:
# extract_pico_elements_exp(0.3, OUTPUT_PATH_t, model, tokenizer, threshold=0.5)

In [None]:
# extract_pico_elements_exp(0.35, OUTPUT_PATH_t, model, tokenizer, threshold=0.5)

In [None]:
# extract_pico_elements_exp(0.4, OUTPUT_PATH_t, model, tokenizer, threshold=0.5)

In [None]:
# extract_pico_elements_exp(0.45, OUTPUT_PATH_t, model, tokenizer, threshold=0.5)

In [None]:
# extract_pico_elements_exp(0.5, OUTPUT_PATH_t, model, tokenizer, threshold=0.5)