In [None]:
import json
import pickle
from pathlib import Path

import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast

from data import BertNENDataset
from model import BiEncoder

In [None]:
# Load the pickled EMRs (file: emrs_with_annots.pickle).
# NOTE: unpickling can execute arbitrary code — only load files from a trusted source.
emrs_path = Path("/nfs/nas-7.1/ckwu/datasets/emr/6000/emrs_with_annots.pickle")
with emrs_path.open("rb") as emrs_file:
    emrs = pickle.load(emrs_file)

In [None]:
# Load the pickled NER spans (presumably one entry per EMR, aligned with `emrs` — TODO confirm).
# The path variable is named `*_path` like the other loader cells; it holds a Path, not the tuples.
# NOTE: unpickling can execute arbitrary code — only load files from a trusted source.
ner_spans_path = Path("/nfs/nas-7.1/ckwu/datasets/emr/6000/ner_spans_tuples.pickle")
ner_spans_l = pickle.loads(ner_spans_path.read_bytes())

In [None]:
# Lookup tables used by the NEN dataset:
#   sm2cui   — mention string -> CUI
#   cui2name — CUI -> preferred name
sm2cui_path = Path("/nfs/nas-7.1/ckwu/datasets/nen/data/single_mention2cui.json")
smcui2name_path = Path("/nfs/nas-7.1/ckwu/datasets/umls/smcui2name.json")

# Read raw bytes so json can auto-detect the file's UTF encoding.
sm2cui = json.loads(sm2cui_path.read_bytes())
cui2name = json.loads(smcui2name_path.read_bytes())

In [None]:
# Build the tokenizer, the NEN dataset and a shuffled single-item loader.
# SEED makes the shuffled draw in the next cell reproducible, so fixed indices
# such as `cui_idx` refer to the same sample on every Restart & Run All.
SEED = 42
torch.manual_seed(SEED)  # global torch RNG (covers any torch-based randomness outside the loader)

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

dataset = BertNENDataset(emrs, ner_spans_l, sm2cui, cui2name, cui_batch_size=16, tokenizer=tokenizer)
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    pin_memory=True,
    # batch_size=1, so unwrap the single item instead of default-collating it.
    collate_fn=lambda batch: batch[0],
    # Dedicated seeded generator: the sampler's shuffle order no longer depends on global RNG state.
    generator=torch.Generator().manual_seed(SEED),
)

In [None]:
# Draw one item; collate_fn returns batch[0] and batch_size=1, so this is a single dataset item.
sample = next(iter(dataloader))
(
    emr_be,
    token_indices_l,
    cuis,
    negative_cuis_l,
) = sample

In [None]:
# Pick one mention of the sampled EMR by index and build its candidate entities:
# the target CUI followed by its negative CUIs.
cui_idx = 8
# The sample comes from a shuffled loader, so fail with a clear message (not a bare
# IndexError) when this EMR has fewer mentions than the chosen index.
assert -len(cuis) <= cui_idx < len(cuis), (
    f"cui_idx={cui_idx} is out of range for a sample with {len(cuis)} CUIs"
)

cui = cuis[cui_idx]
negative_cuis = negative_cuis_l[cui_idx]

ents_be = dataset.make_entities_be(cuis=[cui] + negative_cuis)
ents_labels = dataset.make_entities_labels(target_cui=cui, negative_cuis=negative_cuis)

In [None]:
# Bi-encoder whose encoder matches the tokenizer's checkpoint; deriving the name from the
# tokenizer keeps the two from silently diverging if the checkpoint is changed.
model = BiEncoder(encoder_name=tokenizer.name_or_path)

In [None]:
# Encode the candidate entities built from the target CUI plus its negative CUIs.
y_ents = model.encode_entities(ents_be)

In [None]:
# Encode all mentions of the sampled EMR, then select the one at `cui_idx`.
mentions = model.encode_mentions(emr_be, token_indices_l)
# The four per-mention sequences must all have the same length.
assert len({len(mentions), len(token_indices_l), len(cuis), len(negative_cuis_l)}) == 1
y_ment = mentions[cui_idx]

In [None]:
# Score the selected mention embedding against the candidate entity embeddings and
# compute the loss against the target/negative labels.
scores = model.calc_scores(y_ment, y_ents)
# NOTE(review): `.squeeze()` drops *every* size-1 dim; with a single candidate this may
# collapse `scores` further than calc_loss expects — confirm the expected shape.
loss = model.calc_loss(scores.squeeze(), ents_labels)

In [None]:
# Display the raw scores and the loss for inspection.
scores, loss

In [None]:
# Sanity-check the span -> token alignment for one EMR by decoding each mention's tokens.
# NOTE: rebinds `be` and `token_indices_l`; the following cell uses these values.
emr_idx = 1

emr = emrs[emr_idx]
# The NER spans were loaded as `ner_spans_l`; the previous name `spans_tuples` was undefined.
spans = ner_spans_l[emr_idx]
be = tokenizer(emr, return_offsets_mapping=True)
offsets = be.pop("offset_mapping")
# TODO(review): `spans_to_token_indices_l` is not imported anywhere in this notebook;
# presumably it lives in a project module (e.g. `data`) — add an explicit import.
token_indices_l = spans_to_token_indices_l(spans, offsets)

input_ids = be["input_ids"]
for token_indices in token_indices_l:
    mention = tokenizer.decode([input_ids[token_idx] for token_idx in token_indices])
    print(mention)

In [None]:
# Encode this EMR's mentions. Re-tokenize with return_tensors="pt" (which already adds the
# batch axis) instead of `be.convert_to_tensors(...)`: that call converts `be` in place,
# so re-running this cell or the decode cell would operate on already-converted tensors.
mentions = model.encode_mentions(tokenizer(emr, return_tensors="pt"), token_indices_l)