## Import Packages

In [None]:
import gc
import sys
sys.path.append("/nfs/nas-7.1/ckwu/mtl-icda-ht")

import json
import pickle
from tqdm import tqdm
from pathlib import Path
from argparse import Namespace

import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from data import BertNENDataset, KBEntities
from model import BiEncoder
from test import fullset_evaluate
from utilities.data import split_by_div
from utilities.model import encoder_names_mapping
from utilities.utils import set_seeds, render_exp_name, move_bert_input_to_device

## Configuration

In [None]:
# --- Fresh-run configuration (disabled) --------------------------------------
# Uncomment this section to build `args` from ./config.json, derive the
# experiment / checkpoint directory, and persist the args beside the checkpoint.
# config = json.loads(Path("./config.json").read_bytes())
# args = Namespace(**config)

# args.save_dir = Path(args.save_dir)
# args.exp_name = render_exp_name(args, hparams=["encoder", "seed", "emrbs", "cuibs", "optimizer", "lr", "nepochs", "fold", "remainder"])
# args.ckpt_path = args.save_dir / args.exp_name
# args.ckpt_path.mkdir(parents=True, exist_ok=True)
# (args.ckpt_path / "args.pickle").write_bytes(pickle.dumps(args))

# --- Reuse the args pickled by an already-trained run -------------------------
# NOTE(review): unpickling can execute arbitrary code; only load files you trust.
trained_args_path = Path("/nfs/nas-7.1/ckwu/mtl-icda-ht/components_testing/nen/models/encoder-BERT_seed-42_emrbs-1_cuibs-16_optimizer-Adam_lr-1e-05_nepochs-5_fold-10_remainder-0/args.pickle")
args = pickle.loads(trained_args_path.read_bytes())

# Seed the RNGs with the seed stored in the loaded args.
set_seeds(args.seed)

## Data

### Load Data

In [None]:
# Pickled inputs: the EMRs and their NER spans.
with Path(args.emr_path).open("rb") as emr_file:
    emrs = pickle.load(emr_file)
with Path(args.ner_spans_l_path).open("rb") as spans_file:
    ner_spans_l = pickle.load(spans_file)

# JSON lookups: mention -> CUI and CUI -> name.
with Path(args.sm2cui_path).open("rb") as sm2cui_file:
    sm2cui = json.load(sm2cui_file)
with Path(args.smcui2name_path).open("rb") as smcui2name_file:
    smcui2name = json.load(smcui2name_file)

### Train / Valid Split

In [None]:
# Both splits use the same fold / remainder so EMRs and their NER spans stay aligned.
split_kwargs = dict(fold=args.fold, remainder=args.remainder)

train_emrs = split_by_div(emrs, mode="train", **split_kwargs)
train_ner_spans_l = split_by_div(ner_spans_l, mode="train", **split_kwargs)

valid_emrs = split_by_div(emrs, mode="valid", **split_kwargs)
valid_ner_spans_l = split_by_div(ner_spans_l, mode="valid", **split_kwargs)

### Construct Dataset & DataLoader

In [None]:
tokenizer = AutoTokenizer.from_pretrained(encoder_names_mapping[args.tokenizer])

# Keyword arguments shared by the training and validation datasets.
shared_dataset_kwargs = dict(
    mention2cui=sm2cui,
    cui2name=smcui2name,
    cui_batch_size=args.cuibs,
    tokenizer=tokenizer,
)

train_set = BertNENDataset(emrs=train_emrs, ner_spans_l=train_ner_spans_l, **shared_dataset_kwargs)
valid_set = BertNENDataset(emrs=valid_emrs, ner_spans_l=valid_ner_spans_l, **shared_dataset_kwargs)
entities_set = KBEntities(id2desc=smcui2name, tokenizer=tokenizer)


def take_first_item(batch):
    """Collate by returning only the first item of the sampled batch.

    NOTE(review): any further items in the batch are discarded, so this is
    only lossless when args.emrbs == 1.
    """
    return batch[0]


# EMR loaders: training is shuffled, validation keeps dataset order.
train_loader = DataLoader(
    train_set,
    batch_size=args.emrbs,
    shuffle=True,
    pin_memory=True,
    collate_fn=take_first_item,
)
valid_loader = DataLoader(
    valid_set,
    batch_size=args.emrbs,
    shuffle=False,
    pin_memory=True,
    collate_fn=take_first_item,
)

# Knowledge-base entities are batched with the dataset's own collate function.
entities_loader = DataLoader(
    entities_set,
    batch_size=args.cuibs,
    shuffle=False,
    pin_memory=True,
    collate_fn=entities_set.collate_fn,
)

## Model, Optimizer, and Scheduler

In [None]:
# Build the bi-encoder and move it to the configured device.
model = BiEncoder(encoder_name=encoder_names_mapping[args.encoder]).to(args.device)

# The training cell calls optimizer.zero_grad() / optimizer.step(), so the
# optimizer has to exist before that cell runs (it was commented out, which
# made the training cell fail with NameError). The optimizer class is looked
# up by name in torch.optim from args.optimizer.
optimizer = getattr(torch.optim, args.optimizer)(model.parameters(), lr=args.lr)

## Optimization

In [None]:
# Training loop with periodic full-set evaluation, best-checkpoint saving and
# patience-based early stopping.
nsteps = 0              # number of optimisation steps taken so far
stale = 0               # consecutive evaluations without improvement
best_fullset_acc = 0.0  # best validation accuracy seen so far
early_stop = False      # set once `stale` reaches args.patience

# NOTE: this cell calls optimizer.zero_grad() / optimizer.step(); `optimizer`
# must be defined beforehand (see the "Model, Optimizer, and Scheduler" cell).
for epoch in range(1, args.nepochs + 1):
    print(f"\n===== Start training at epoch {epoch} =====\n")
    pbar = tqdm(total=len(train_loader), ncols=0, desc="Train", unit=" steps")

    for emr_be, mention_indices_l, target_cuis, negative_cuis_l in train_loader:
        # Re-enter training mode on every step (the periodic evaluation below
        # may leave the model in eval mode — TODO confirm in fullset_evaluate).
        model.train()

        # Encode mentions
        emr_be = move_bert_input_to_device(emr_be, args.device)

        mentions = model.encode_mentions(emr_be, mention_indices_l)
        assert len(mentions) == len(mention_indices_l) == len(target_cuis) == len(negative_cuis_l)

        # Encode entities and accumulate one loss term per mention
        emr_loss = torch.tensor([0.0]).to(args.device)
        for mention, target_cui, negative_cuis in zip(mentions, target_cuis, negative_cuis_l):
            # The target CUI goes first, followed by its sampled negatives.
            batch_cuis = [target_cui] + negative_cuis
            assert len(batch_cuis) == args.cuibs
            ents_be = train_set.make_entities_be(cuis=batch_cuis).to(args.device)
            ents_labels = train_set.make_entities_labels(target_cui, negative_cuis).to(args.device)

            y_ment = mention
            y_ents = model.encode_entities(ents_be)

            # Calculate score & loss
            scores = model.calc_scores(y_ment, y_ents)
            loss = model.calc_loss(scores.squeeze(), ents_labels)

            # Accumulate loss
            emr_loss += loss

        # Update parameters. With no mentions, emr_loss stays a constant
        # without a graph, so backward() is skipped.
        optimizer.zero_grad()
        if emr_loss.requires_grad:
            emr_loss.backward()
        optimizer.step()

        # Evaluate every k steps (this also runs at step 0, so `fullset_acc`
        # is always defined before the progress bar reads it).
        if nsteps % args.ckpt_steps == 0:
            # NOTE(review): the Evaluation section unpacks two values
            # (accuracy, batch execution times) from fullset_evaluate; if the
            # function returns a tuple, this call must be unpacked as well.
            fullset_acc = fullset_evaluate(valid_loader, model, args, entities_loader=entities_loader)
            print(f"Model evaluated at step {nsteps}: accuracy = {fullset_acc:.3f}")
            if fullset_acc > best_fullset_acc:
                stale = 0
                best_fullset_acc = fullset_acc
                print("Saving best model.")
                torch.save(model.state_dict(), args.ckpt_path / "best_model.ckpt")
            else:
                stale += 1
                print(f"Model stop improving for {int(stale * args.ckpt_steps)} steps.")
                if stale >= args.patience:
                    print("Stop training because of early stopping.")
                    # Previously only the message was printed and training
                    # kept going; the flag now ends both loops.
                    early_stop = True

        # Report the mean per-mention loss (or the raw value when there are no mentions).
        mean_loss = (emr_loss.detach().cpu().item() / len(mentions)) if len(mentions) > 0 else emr_loss.detach().cpu().item()
        nsteps += 1
        pbar.update(n=1)
        pbar.set_postfix(
            loss=f"{mean_loss:.3f}",
            acc=f"{fullset_acc:.3f}",
            steps=nsteps
        )

        # Drop references to this step's loss before the next step.
        del emr_loss, mean_loss
        gc.collect()

        if early_stop:
            break

    pbar.close()

    if early_stop:
        break

## Evaluation

In [None]:
# Restore the best checkpoint written during training and switch to eval mode.
best_ckpt_path = args.ckpt_path / "best_model.ckpt"
model.load_state_dict(torch.load(best_ckpt_path))
model.eval()
None  # trailing None keeps the notebook from echoing a value for this cell

In [None]:
# Entity embeddings stored alongside the checkpoint.
entity_embeddings_path = args.ckpt_path / "entity_embeddings_5454.pt"
entity_embeddings = torch.load(entity_embeddings_path)

### Full Set Accuracy

In [None]:
# Evaluate on the validation loader against the loaded entity embeddings;
# the call yields the accuracy and the per-batch execution times.
fullset_acc, batch_exec_times = fullset_evaluate(
    valid_loader,
    model,
    args,
    entity_embeddings,
    entities_loader=entities_loader,
)
total_exec_time = sum(batch_exec_times)
fullset_acc, total_exec_time

### Execution Time Analysis

In [None]:
import numpy as np

# Batch execution times as an array, then their mean and standard deviation.
bets = np.array(batch_exec_times)

np.mean(bets), np.std(bets)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
# IPython magic: render matplotlib figures inline in the notebook.
%matplotlib inline

# Density-normalised histogram of the batch execution times with a KDE overlay.
sns.histplot(data=bets, stat="density", kde=True)

### Qualitative Evaluation

In [None]:
from typing import Dict, List

# Inputs for the qualitative pass
data_loader = valid_loader
all_y_ents = entity_embeddings.to(args.device)
all_cuis, all_descs = entities_set._ids, entities_set._descs

# Quantitative counters: correct predictions and total predictions
total_correct = total_predict = 0

# Qualitative co-occurrence counts:
#   ment2ent_count[mention surface][predicted description] -> count
#   ent2ment_count[predicted description][mention surface] -> count
ment2ent_count: Dict[str, Dict[str, int]] = {}
ent2ment_count: Dict[str, Dict[str, int]] = {}

def update_dicts(
    ment2ent: Dict[str, Dict[str, int]],
    ent2ment: Dict[str, Dict[str, int]],
    mention_surfs: List[str],
    pred_descs: List[str],
) -> None:
    """Count mention/prediction co-occurrences in both directions, in place.

    For every aligned pair ``(mention_surfs[i], pred_descs[i])`` the counters
    ``ment2ent[mention][description]`` and ``ent2ment[description][mention]``
    are each incremented by one; missing inner dicts are created on demand.

    Args:
        ment2ent: Nested counter keyed by mention surface, then description.
        ent2ment: Nested counter keyed by description, then mention surface.
        mention_surfs: Surface forms of the mentions.
        pred_descs: Predicted entity descriptions, aligned with mention_surfs.

    Raises:
        AssertionError: If the two lists differ in length.
    """
    assert len(mention_surfs) == len(pred_descs)

    for surface, description in zip(mention_surfs, pred_descs):
        # setdefault creates the inner counter the first time a key is seen.
        by_entity = ment2ent.setdefault(surface, {})
        by_mention = ent2ment.setdefault(description, {})

        by_entity[description] = by_entity.get(description, 0) + 1
        by_mention[surface] = by_mention.get(surface, 0) + 1

model.eval()
for emr_be, mention_indices_l, target_cuis, _ in data_loader:  # negative CUIs are not used here
    # Decode each mention's token IDs back into its surface form
    # (done before the batch is moved to the device).
    mention_surfs = [
        tokenizer.decode(emr_be["input_ids"][0][mention_indices])
        for mention_indices in mention_indices_l
    ]

    with torch.no_grad():
        emr_be = move_bert_input_to_device(emr_be, args.device)
        y_ments = model.encode_mentions(emr_be, mention_indices_l)
        assert len(y_ments) == len(mention_indices_l) == len(target_cuis)

        # Score every mention against all entity embeddings and keep the
        # index of the highest-scoring entity.
        scores = model.calc_scores(y_ments, all_y_ents)
        preds = scores.argmax(dim=-1).cpu().tolist()

    # Pair each predicted index with its gold CUI.
    pred_target_pairs = list(zip(preds, target_cuis))
    pred_descs = [all_descs[pred_idx] for pred_idx, _gold in pred_target_pairs]

    total_correct += sum(1 for pred_idx, gold_cui in pred_target_pairs if all_cuis[pred_idx] == gold_cui)
    total_predict += len(preds)

    # Update stats of the two co-occurrence dicts
    update_dicts(ment2ent_count, ent2ment_count, mention_surfs, pred_descs)

In [None]:
# Serialise the mention -> entity counts and write them next to the checkpoint.
m2e_path = Path(args.ckpt_path / "ment2ent_count.json")
m2e_json = json.dumps(ment2ent_count)
m2e_path.write_text(m2e_json)

In [None]:
# Serialise the entity -> mention counts and write them next to the checkpoint.
e2m_path = Path(args.ckpt_path / "ent2ment_count.json")
e2m_json = json.dumps(ent2ment_count)
e2m_path.write_text(e2m_json)

### Previous Evaluation Results

In [None]:
# Full-set accuracies recorded from earlier runs, kept here for reference.
# The untrained_* values are for encoders before training; trained_bert_* is after training.
# NOTE(review): presumably measured on the same validation split as above — confirm.
untrained_bert_fullset_acc = 0.001633605600933489
untrained_linkbert_fullset_acc = 0.00023380874444704232
untrained_biobert_fullset_acc = 0.11421557166238017
untrained_clinicalbert_fullset_acc = 0.0911854103343465
trained_bert_fullset_acc = 0.8989498249708284

In [None]:
# (loss, accuracy) pairs recorded from earlier runs, kept here for reference.
# NOTE(review): these look like a different metric from the full-set accuracies
# (presumably evaluated against sampled candidates) — confirm.
untrained_bert_loss = 17.59189080480663
untrained_bert_acc = 0.1294049008168028

untrained_biobert_loss = 86.79308684462555
untrained_biobert_acc = 0.7194295066635492

untrained_clinicalbert_loss = 103.7931861654464
untrained_clinicalbert_acc = 0.6738368014963759

trained_loss = 0.010365398921037055
trained_acc = 0.9957992998833138