## Import Packages

In [None]:
import warnings
import sys
# Make the project root importable so the `utilities` package imported below resolves.
# This must run before the `utilities` imports.
# NOTE(review): absolute, machine-specific path — adjust when running on another host.
sys.path.append("/nfs/nas-7.1/ckwu/mtl-icda-ht")

import numpy as np
import torch
import torch.nn as nn
import random
import json
import jsonlines
import pickle

from argparse import Namespace
from pathlib import Path
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast

# Project-local helpers (resolved via the sys.path entry above).
from utilities.data import MedicalNERIOBDataset, split_by_div
from utilities.utils import set_seeds, move_bert_input_to_device, visualize_iob_labels
from utilities.model import BertNERModel, encoder_names_mapping

# Suppress every warning for the rest of the session.
# NOTE(review): this also hides potentially useful warnings; consider narrowing the filter.
warnings.filterwarnings("ignore")

## Config

In [None]:
# Load the experiment configuration; `args` exposes the same keys as attributes.
with open("./config.json", encoding="utf-8") as f:
    config = json.load(f)

args = Namespace(**config)

set_seeds(args.seed)
# Explicit check instead of `assert`: asserts are stripped under `python -O`,
# which would let the notebook continue without a GPU and fail later.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available; this notebook requires a GPU.")

## Data

In [None]:
# Inputs (EMR texts) and targets (NER span tuples) are pickled files named in config.json.
# NOTE(review): pickle.loads can execute arbitrary code — only point these paths at trusted files.
emr_path = Path(args.emr_path)  # x: EMR
emrs = pickle.loads(emr_path.read_bytes())
spans_tuples_path = Path(args.ner_spans_tuples_path)  # y: NER labels
spans_tuples = pickle.loads(spans_tuples_path.read_bytes())

# Split both inputs and labels with the same fold/remainder so they stay aligned.
div_kwargs = {"fold": args.fold, "remainder": args.remainder}
train_emrs = split_by_div(emrs, mode="train", **div_kwargs)
train_labels = split_by_div(spans_tuples, mode="train", **div_kwargs)
valid_emrs = split_by_div(emrs, mode="valid", **div_kwargs)
valid_labels = split_by_div(spans_tuples, mode="valid", **div_kwargs)

tokenizer = BertTokenizerFast.from_pretrained(encoder_names_mapping[args.encoder])
train_set, valid_set = (
    MedicalNERIOBDataset(emrs=emrs_part, spans_tuples=labels_part, tokenizer=tokenizer)
    for emrs_part, labels_part in ((train_emrs, train_labels), (valid_emrs, valid_labels))
)

train_loader = DataLoader(
    train_set,
    batch_size=args.bs,
    shuffle=True,  # reshuffle training examples each epoch
    pin_memory=True,
    collate_fn=train_set.collate_fn,
)
valid_loader = DataLoader(
    valid_set,
    batch_size=args.bs,
    shuffle=False,  # fixed order so validation batches are reproducible
    pin_memory=True,
    collate_fn=valid_set.collate_fn,
)

## Model

In [None]:
# Build the NER model and restore fine-tuned weights from a saved state dict.
# NOTE(review): the checkpoint path is hard-coded (its filename suggests a BioBERT
# encoder, fold 10, remainder 0); confirm it matches args.encoder and the tag set
# of train_set whenever config.json changes.
ckpt_path = "/nfs/nas-7.1/ckwu/mtl-icda-ht/components_testing/ner/models/encoder-BioBERT_nepochs-5_bs-16_lr-4e-05_fold-10_remainder-0.pth"
model = BertNERModel(encoder=encoder_names_mapping[args.encoder], num_tags=train_set.num_tags).to(args.device)
model.load_state_dict(torch.load(ckpt_path, map_location=args.device))

# Loss passed to trainer(...) in the next cell, which fails with NameError if it is undefined.
# Targets equal to -100 are excluded from the loss (presumably padding/sub-word positions).
criterion = nn.CrossEntropyLoss(reduction="mean", ignore_index=-100)

In [None]:
# Run the train/validation loop; the returned `record` is written to ./eval_results below.
# NOTE(review): `trainer` is neither imported nor defined anywhere in this notebook
# (check the import cell) — TODO import it from the project before running this cell.
# `criterion` must also be defined before this cell runs.
record = trainer(train_loader, valid_loader, model, criterion, args)

In [None]:
# save evaluation results
# Save the trainer's record as JSON to ./eval_results/<model_save_name>.json.
eval_dir = Path("./eval_results")
# open() raises FileNotFoundError if the directory is missing, so create it first.
eval_dir.mkdir(parents=True, exist_ok=True)
with open(eval_dir / "{}.json".format(config["model_save_name"]), "wt", encoding="utf-8") as f:
    json.dump(record, f)

## Evaluate Model

In [None]:
# Take the first validation batch (valid_loader does not shuffle) and move it
# onto the model's device for inference below.
batch_inputs, batch_labels = next(iter(valid_loader))
x = move_bert_input_to_device(batch_inputs, args.device)
y = batch_labels.to(args.device)

In [None]:
# Evaluation forward pass. A freshly built nn.Module, or one just trained, is in
# training mode, where dropout (typical in BERT encoders) would randomize predictions.
# eval() disables it. no_grad() skips building the autograd graph, which this
# cell never uses.
model.eval()
with torch.no_grad():
    scores = model(x)

## Visualize Labels

In [None]:
# Render the predicted IOB tag of every token for one example of the batch.
idx = 9  # index along the first (presumably batch) dimension; must be < that size

predicted_tag_ids = scores.argmax(dim=-1)  # highest-scoring tag id per position
input_ids = x["input_ids"][idx].tolist()
label_ids = predicted_tag_ids[idx].tolist()

visualize_iob_labels(tokenizer, input_ids, label_ids, train_set.idx2iob)