## Import Packages

In [None]:
import warnings
import sys
sys.path.append("/nfs/nas-7.1/ckwu/mtl-icda-ht")

import numpy as np
import torch
import torch.nn as nn
import random
import json
import pickle
import gc

from argparse import Namespace
from pathlib import Path
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from utilities.data import MedicalNERIOBDataset, split_by_div
from utilities.utils import set_seeds, load_config, render_exp_name, move_bert_input_to_device
from utilities.model import BertNERModel, encoder_names_mapping
from utilities.evaluation import visualize_iobpol_labels

## Config

In [None]:
# Read the experiment configuration and expose its entries as attributes.
config = load_config()
args = Namespace(**config)

# Fields handed to render_exp_name() when building the experiment name.
hparams = ["encoder", "optimizer", "lr", "nepochs", "bs", "fold", "remainder"]

# The batch size must be an exact multiple of grad_accum_steps
# (presumably so a batch splits evenly across accumulation steps — TODO confirm).
assert args.bs % args.grad_accum_steps == 0

args.exp_name = render_exp_name(args, hparams)

# One checkpoint directory per experiment, created up front if it is missing.
ckpt_dir = Path(args.save_dir) / args.exp_name
args.ckpt_path = ckpt_dir
ckpt_dir.mkdir(parents=True, exist_ok=True)

# Snapshot the resolved args (including ckpt_path) next to the checkpoints.
args_payload = pickle.dumps(args)
(ckpt_dir / "args.pickle").write_bytes(args_payload)

set_seeds(args.seed)

## Data

In [None]:
# Load the pickled EMR texts and their NER span annotations.
# NOTE(review): unpickling runs arbitrary code; only point these paths at trusted files.
emr_bytes = Path(args.emr_path).read_bytes()
emrs = pickle.loads(emr_bytes)
ner_span_bytes = Path(args.ner_spans_l_path).read_bytes()
ner_spans_l = pickle.loads(ner_span_bytes)

# Texts and labels are split with identical fold/remainder settings;
# presumably this keeps the two lists aligned — verify against split_by_div.
split_kwargs = {"fold": args.fold, "remainder": args.remainder}
train_emrs = split_by_div(emrs, mode="train", **split_kwargs)
train_labels = split_by_div(ner_spans_l, mode="train", **split_kwargs)
valid_emrs = split_by_div(emrs, mode="valid", **split_kwargs)
valid_labels = split_by_div(ner_spans_l, mode="valid", **split_kwargs)

# Fast tokenizer for the encoder selected in the config.
encoder_name = encoder_names_mapping[args.encoder]
tokenizer = AutoTokenizer.from_pretrained(encoder_name, use_fast=True)

### Developing IOB with Polarity Label Mapping

### Check Correctness of bert_offsets_to_iob_labels()

In [None]:
# Build the training and validation datasets from the split texts and spans.
# NOTE(review): MedicalIOBPOLDataset is not imported in the visible cells (the
# import cell brings in MedicalNERIOBDataset) — confirm where this class comes from.
train_set, valid_set = (
    MedicalIOBPOLDataset(text_l=texts, ner_spans_l=spans, tokenizer=tokenizer)
    for texts, spans in (
        (train_emrs, train_labels),
        (valid_emrs, valid_labels),
    )
)

In [None]:
from colorama import Fore, Style

# Terminal colour for each label id 0-4, in id order; id 0 resets the style.
label_color_mappings = dict(
    enumerate((Style.RESET_ALL, Fore.GREEN, Fore.CYAN, Fore.RED, Fore.YELLOW))
)
# -100 also resets the style (presumably the ignored-label id — TODO confirm).
label_color_mappings[-100] = Style.RESET_ALL

In [None]:
# Pick one validation sample and render its dataset labels in colour.
emr_idx = 12
text_be, iob_labels = valid_set[emr_idx]

sample_input_ids = text_be["input_ids"].tolist()
visualize_iobpol_labels(tokenizer, sample_input_ids, iob_labels, label_color_mappings)

In [22]:
# Settings shared by both loaders; only the training loader shuffles.
loader_kwargs = {"batch_size": args.bs, "pin_memory": True}
train_loader = DataLoader(
    train_set,
    shuffle=True,
    collate_fn=train_set.collate_fn,
    **loader_kwargs,
)
valid_loader = DataLoader(
    valid_set,
    shuffle=False,
    collate_fn=valid_set.collate_fn,
    **loader_kwargs,
)

## Model

In [None]:
# Build the NER model on the configured device and restore trained weights
# from a fixed checkpoint file.
model = BertNERModel(encoder=encoder_names_mapping[args.encoder], num_tags=train_set.num_tags).to(args.device)
pretrained_state = torch.load(
    "/nfs/nas-7.1/ckwu/mtl-icda-ht/components_testing/ner/models/encoder-BioBERT_nepochs-5_bs-16_lr-4e-05_fold-10_remainder-0.pth",
    map_location=args.device,
)
model.load_state_dict(pretrained_state)

# The training cell below passes `criterion` to trainer(), so it has to be
# defined here. Targets equal to -100 are excluded from the loss.
criterion = nn.CrossEntropyLoss(reduction="mean", ignore_index=-100)

In [None]:
# Presumably runs the training/validation loop; the returned value is kept in
# `record` and dumped to JSON by the next cell.
# NOTE(review): `trainer` is neither defined nor imported in the visible cells —
# confirm where it comes from before running this cell.
record = trainer(train_loader, valid_loader, model, criterion, args)

In [None]:
# Save evaluation results as JSON. The output directory is created first so the
# write does not fail with FileNotFoundError when ./eval_results is missing.
eval_dir = Path("./eval_results")
eval_dir.mkdir(parents=True, exist_ok=True)
with open(eval_dir / "{}.json".format(config["model_save_name"]), "wt", encoding="utf-8") as f:
    json.dump(record, f)

## Evaluate Model

In [None]:
# Take the first validation batch and move its inputs and labels to the
# configured device.
batch_inputs, batch_labels = next(iter(valid_loader))
x = move_bert_input_to_device(batch_inputs, args.device)
y = batch_labels.to(args.device)

In [None]:
# Inference only: eval mode disables dropout and other train-time behaviour,
# and no_grad() avoids building an autograd graph for the forward pass.
model.eval()
with torch.no_grad():
    scores = model(x)

## Visualize Labels

In [None]:
# Position of the sample to inspect within the current batch.
idx = 9

input_ids = x["input_ids"][idx].tolist()
# Predicted label id per token: argmax over the last dimension of the scores
# (assumes scores is (batch, seq_len, num_tags) — TODO confirm).
label_ids = scores.argmax(dim=-1)[idx].tolist()

# The original call used `visualize_iob_labels`, which is never imported and
# would raise NameError. Use the imported visualiser with the colour mapping,
# matching how the dataset labels are rendered earlier in this notebook.
# NOTE(review): if the IOB-string renderer with train_set.idx2iob was intended,
# import visualize_iob_labels explicitly instead.
visualize_iobpol_labels(tokenizer, input_ids, label_ids, label_color_mappings)