## Imports and functions

In [1]:
import os 
# Process-wide environment configuration. This cell runs before the cell that
# imports torch / transformers.
os.environ['PYTHONUNBUFFERED'] = '1'
# Restrict the process to GPU 0. The key used to be misspelled as
# 'CUDA_VISIBLE_DEVIC', so the restriction was never applied.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
os.environ['OMP_NUM_THREADS'] = '4'
# Read back via os.environ.get('INHERIT_BERT', 0) in get_model().
os.environ['INHERIT_BERT'] = '1'

In [2]:
import argparse
import logging
import pickle
from typing import List

# from tensorboardX import SummaryWriter
import torch
import transformers
from tqdm import tqdm

from modeling.modeling_dragon import DRAGON
from utils.data_utils import Custom_DataLoader, DRAGON_DataLoader, simple_convert_examples_to_features, InputExample, \
    MODEL_NAME_TO_CLASS, InputFeatures

try:
    from transformers import (ConstantLRSchedule, WarmupLinearSchedule, WarmupConstantSchedule, BertTokenizer,
                              AlbertTokenizer, XLNetTokenizer, RobertaTokenizer)
except:
    from transformers import get_constant_schedule, get_constant_schedule_with_warmup, get_linear_schedule_with_warmup, \
        BertTokenizer
import wandb
from transformers import (OpenAIGPTTokenizer, BertTokenizer, BertTokenizerFast, XLNetTokenizer, RobertaTokenizer,
                          RobertaTokenizerFast)
from modeling import modeling_dragon
from utils import utils

import numpy as np

import socket, os, sys, subprocess

# Module-level logger, named after the module this code runs in.
logger = logging.getLogger(__name__)


class Custom_DataLoader(DRAGON_DataLoader):
    """Data loader that feeds one in-memory question (plus its answer options) to DRAGON.

    NOTE: this definition replaces the ``Custom_DataLoader`` name imported from
    ``utils.data_utils`` in the import section. ``DRAGON_DataLoader.__init__`` is not
    called here; only the attributes assigned in ``__init__`` below are set up.
    """

    def __init__(self, question: str, answers: List[str], args, train_statement_path, train_adj_path,
                 dev_statement_path, dev_adj_path,
                 test_statement_path, test_adj_path,
                 batch_size, eval_batch_size, device, model_name, max_node_num=200, max_seq_length=128,
                 is_inhouse=False, inhouse_train_qids_path=None,
                 subsample=1.0, n_train=-1, debug=False, cxt_node_connects_all=False, kg="cpnet"):
        """Tokenise ``question``/``answers`` and load the graph data found at ``train_adj_path``.

        Args:
            question: Question text; used as the context of every answer option.
            answers: Answer option texts.
            args: Run configuration; ``local_rank`` and ``world_size`` are read in
                ``load_input_tensors``.
            train_statement_path: Only printed, never read.
            train_adj_path: Path handed to ``load_sparse_adj_data_with_contextnode``.
            device: Pair ``(device0, device1)``.
            model_name: Key into ``MODEL_NAME_TO_CLASS`` selecting the LM family.
            max_node_num: Forwarded to ``load_sparse_adj_data_with_contextnode``.
            max_seq_length: Tokenised sequence length (padding/truncation target).
            kg: Knowledge-graph name forwarded to ``load_resources``.

        The remaining parameters (``dev_*``, ``test_*``, ``inhouse_train_qids_path``,
        ``subsample``, ``n_train``) are accepted for signature compatibility and ignored.
        """
        self.args = args
        self.batch_size = batch_size
        self.eval_batch_size = eval_batch_size
        self.device0, self.device1 = device
        self.is_inhouse = is_inhouse
        self.debug = debug
        self.model_name = model_name
        self.max_node_num = max_node_num
        self.debug_sample_size = 32
        self.cxt_node_connects_all = cxt_node_connects_all

        self.model_type = MODEL_NAME_TO_CLASS[model_name]
        # Not defined in this class; presumably inherited from DRAGON_DataLoader and
        # responsible for setting self.tokenizer -- TODO confirm.
        self.load_resources(kg)

        # Build the LM inputs from the in-memory question instead of a statement file.
        print('train_statement_path', train_statement_path)
        self.train_qids, self.train_labels, self.train_encoder_data, train_concepts_by_sents_list = self.load_input_tensors(
            question, answers, max_seq_length, mode='train')

        num_choice = self.train_encoder_data[0].size(1)
        self.num_choice = num_choice
        print('num_choice', num_choice)
        # The last returned element is the adjacency data; everything before it is decoder data.
        *self.train_decoder_data, self.train_adj_data = self.load_sparse_adj_data_with_contextnode(train_adj_path,
                                                                                                   max_node_num,
                                                                                                   train_concepts_by_sents_list,
                                                                                                   mode='train')

    def load_input_tensors(self, question, answers, max_seq_length, mode='eval'):
        """Construct input tensors for the LM component of the model.

        Args:
            question: Question text.
            answers: Answer option texts.
            max_seq_length: Tokenised sequence length.
            mode: When ``'train'`` and ``args.local_rank != -1`` the examples are padded to a
                multiple of ``args.world_size`` and only this rank's shard is kept.

        Returns:
            ``(example_ids, all_label, data_tensors, concepts_by_sents_list)``.

        Raises:
            ValueError: If ``self.model_type`` is not bert/xlnet/roberta/albert.
        """
        if self.model_type not in ('bert', 'xlnet', 'roberta', 'albert'):
            raise ValueError("Unsupported model type: {!r}".format(self.model_type))

        # The gold label is unknown for an ad-hoc question, so pass the -100 sentinel by
        # keyword. The previous call passed self.debug_sample_size (32) positionally, which
        # landed in the `label` parameter and was mapped to -100 downstream only because 32
        # is not a valid choice index.
        input_tensors = load_bert_xlnet_roberta_input_from_text(question, answers, max_seq_length, self.debug,
                                                                self.tokenizer, label=-100)

        if mode == 'train' and self.args.local_rank != -1:
            example_ids, all_label, data_tensors, concepts_by_sents_list = input_tensors  # concepts_by_sents_list is always []
            assert len(example_ids) == len(all_label) == len(data_tensors[0])
            world_size = self.args.world_size
            total_num = len(data_tensors[0])
            rem = total_num % world_size
            if rem != 0:
                # Pad with the leading examples so every rank receives the same count.
                pad = world_size - rem
                example_ids = example_ids + example_ids[:pad]
                all_label = torch.cat([all_label, all_label[:pad]], dim=0)
                data_tensors = [torch.cat([t, t[:pad]], dim=0) for t in data_tensors]
                total_num_aim = total_num + pad
            else:
                total_num_aim = total_num
            assert total_num_aim % world_size == 0
            assert total_num_aim == len(data_tensors[0])
            # Round-robin assignment: example i belongs to rank (i % world_size).
            _select = (torch.arange(total_num_aim) % world_size) == self.args.local_rank  # bool tensor
            example_ids = np.array(example_ids)[_select].tolist()
            all_label = all_label[_select]
            data_tensors = [t[_select] for t in data_tensors]
            input_tensors = (example_ids, all_label, data_tensors, [])
        example_ids = input_tensors[0]
        print('local_rank', self.args.local_rank, 'len(example_ids)', len(example_ids), file=sys.stderr)
        return input_tensors


def load_bert_xlnet_roberta_input_from_text(
        contexts: str,
        answers: List[str],
        max_seq_length, debug, tokenizer, label=3):
    """Turn one question and its answer options into LM input tensors.

    Args:
        contexts: Question text, e.g. "A 23-year-old pregnant woman at 22 weeks gestation presents wit".
        answers: Answer options, e.g. ['Ampicillin', 'Ceftriaxone', 'Doxycycline', 'Nitrofurantoin'].
        max_seq_length: Sequence length passed on to the feature converter.
        debug: Passed on to the feature converter.
        tokenizer: Tokenizer passed on to the feature converter.
        label: Label stored on the single example that is built.

    Returns:
        ``(example_ids, all_label, data_tensors, concepts_by_sents_list)`` where
        ``data_tensors`` is the list ``[input_ids, input_mask, segment_ids, output_mask]``.
    """
    # A single example: the same context is paired with every answer option.
    single_example = InputExample(
        example_id="train-00000",
        contexts=[contexts] * len(answers),
        question="",
        endings=answers,
        label=label
    )
    choice_indices = list(range(len(single_example.endings)))
    features, concepts_by_sents_list = simple_convert_examples_to_features([single_example],
                                                                           choice_indices,
                                                                           max_seq_length, tokenizer, debug)

    def stack_field(field, dtype):
        # One row per feature, one entry per answer choice.
        rows = []
        for feature in features:
            rows.append([choice[field] for choice in feature.choices_features])
        return torch.tensor(rows, dtype=dtype)

    example_ids = [feature.example_id for feature in features]
    data_tensors = [
        stack_field('input_ids', torch.long),
        stack_field('input_mask', torch.long),
        stack_field('segment_ids', torch.long),
        stack_field('output_mask', torch.bool),
    ]
    all_label = torch.tensor([feature.label for feature in features], dtype=torch.long)
    return example_ids, all_label, data_tensors, concepts_by_sents_list


def simple_convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, debug=False):
    """Tokenise multiple-choice examples into a list of ``InputFeatures``.

    Each (context, ending) pair of an example is encoded by ``tokenizer`` as a sentence
    pair, padded/truncated to ``max_seq_length``. The second sentence is
    ``example.question + " " + ending``.

    Args:
        examples: Objects exposing ``example_id``, ``contexts``, ``question``, ``endings``
            and ``label``.
        label_list: Valid labels; a label is mapped to its position in this list, and a
            label not in the list is mapped to -100.
        max_seq_length: Required length of every encoded sequence.
        tokenizer: Callable tokenizer returning ``input_ids``, ``attention_mask``,
            ``token_type_ids`` and ``special_tokens_mask``.
        debug: When true, only the first 32 examples are converted.

    Returns:
        ``(features, concepts_by_sents_list)``; the second element is always an empty list.
    """
    label_to_index = dict((label, position) for position, label in enumerate(label_list))

    def encode_choice(context, second_sentence):
        # Encode one (context, answer) pair and verify every field has the target length.
        encoding = tokenizer(context, second_sentence, padding="max_length", truncation=True,
                             max_length=max_seq_length,
                             return_token_type_ids=True, return_special_tokens_mask=True)
        input_ids = encoding["input_ids"]
        output_mask = encoding["special_tokens_mask"]
        input_mask = encoding["attention_mask"]
        segment_ids = encoding["token_type_ids"]
        for sequence in (input_ids, output_mask, input_mask, segment_ids):
            assert len(sequence) == max_seq_length
        return input_ids, input_mask, segment_ids, output_mask

    features = []
    progress = tqdm(enumerate(examples), total=len(examples), desc="Converting examples to features")
    for position, example in progress:
        if debug and position >= 32:
            break
        encoded_choices = [
            encode_choice(context, example.question + " " + ending)
            for context, ending in zip(example.contexts, example.endings)
        ]
        features.append(InputFeatures(example_id=example.example_id,
                                      choices_features=encoded_choices,
                                      label=label_to_index.get(example.label, -100)))

    return features, []


def construct_model(args, kg, dataset_final_num_relation=100):
    """Build a DRAGON model sized for the given knowledge graph.

    Args:
        args: Run configuration (embedding paths, GNN/FC sizes, dropout rates, encoder name ...).
        kg: One of "cpnet", "ddb" or "umls".
        dataset_final_num_relation: Relation count, used only for "umls", where the
            edge-type count is twice this value.

    Returns:
        The ``modeling_dragon.DRAGON`` instance that was constructed.

    Raises:
        ValueError: If ``kg`` is not one of the supported names.
    """
    # ---- Pretrained concept embeddings: load each file, concatenate along axis 1 ----
    embedding_parts = []
    for emb_path in args.ent_emb_paths:
        embedding_parts.append(np.load(emb_path))
    concept_emb = torch.tensor(np.concatenate(embedding_parts, 1), dtype=torch.float)

    concept_num = concept_emb.size(0)
    concept_in_dim = concept_emb.size(1)
    print('| num_concepts: {} |'.format(concept_num))
    if not args.random_ent_emb:
        freeze_ent_emb = args.freeze_ent_emb
    else:
        # Random embeddings: drop the loaded tensor, never freeze, use the GNN width.
        concept_emb = None
        freeze_ent_emb = False
        concept_in_dim = args.gnn_dim

    # ---- Node / edge type counts for the chosen knowledge graph ----
    n_ntype = 4
    if kg == "umls":
        n_etype = dataset_final_num_relation * 2
    elif kg == "cpnet":
        n_etype = 38
    elif kg == "ddb":
        n_etype = 34
    else:
        raise ValueError("Invalid KG.")
    if args.cxt_node_connects_all:
        n_etype += 2
    print('n_ntype', n_ntype, 'n_etype', n_etype)
    print('n_ntype', n_ntype, 'n_etype', n_etype, file=sys.stderr)

    # ---- Model ----
    # An empty encoder_load_path falls back to the encoder name.
    encoder_load_path = args.encoder_load_path or args.encoder
    return modeling_dragon.DRAGON(
        args, encoder_load_path,
        k=args.k,
        n_ntype=n_ntype,
        n_etype=n_etype,
        n_concept=concept_num,
        concept_dim=args.gnn_dim,
        concept_in_dim=concept_in_dim,
        n_attention_head=args.att_head_num,
        fc_dim=args.fc_dim,
        n_fc_layer=args.fc_layer_num,
        p_emb=args.dropouti,
        p_gnn=args.dropoutg,
        p_fc=args.dropoutf,
        pretrained_concept_emb=concept_emb,
        freeze_ent_emb=freeze_ent_emb,
        init_range=args.init_range,
        ie_dim=args.ie_dim,
        info_exchange=args.info_exchange,
        ie_layer_num=args.ie_layer_num,
        sep_ie_layers=args.sep_ie_layers,
        layer_id=args.encoder_layer,
    )


def get_pred(question: str, answers: List[str], model: DRAGON):
    """Predict which answer option the model prefers for a single question.

    Builds a ``Custom_DataLoader`` for the question, runs the model without gradients on
    the batches from ``eval_set.train()``, and returns as soon as the first prediction
    is available.

    Relies on the notebook globals ``args``, ``devices`` and ``kg``.

    Args:
        question: Question text.
        answers: Answer option texts.
        model: DRAGON model used for the forward pass.

    Returns:
        ``(line, pred)``: ``line`` is ``'<qid>,<letter>'`` with the letter counted from
        ``'A'``; ``pred`` is the 0-based predicted index taken from ``logits.argmax(1)``.

    Raises:
        RuntimeError: If the loader yields no prediction at all (previously the function
            silently returned ``None`` in that case).
    """

    eval_set = Custom_DataLoader(question, answers, args, args.train_statements, args.train_adj,
                                 args.dev_statements, args.dev_adj,
                                 args.test_statements, args.test_adj,
                                 batch_size=args.batch_size, eval_batch_size=args.eval_batch_size,
                                 device=devices,
                                 model_name=args.encoder,
                                 max_node_num=args.max_node_num, max_seq_length=args.max_seq_len,
                                 is_inhouse=args.inhouse, inhouse_train_qids_path=args.inhouse_train_qids,
                                 subsample=args.subsample, n_train=args.n_train, debug=args.debug,
                                 cxt_node_connects_all=args.cxt_node_connects_all, kg=kg)

    with torch.no_grad():
        for qids, _labels, *input_data in tqdm(eval_set.train(), desc="Dev/Test batch"):
            # Only the logits are needed; the two auxiliary losses are ignored.
            logits, _mlm_loss, _link_losses = model(*input_data)
            predictions = logits.argmax(1)  # [bsize, ]
            for qid, pred in zip(qids, predictions):
                # Deliberate early exit: only the first prediction is reported.
                return '{},{}'.format(qid, chr(ord('A') + pred.item())), pred

    raise RuntimeError("The data loader produced no batch, so no prediction could be made.")


def get_model(args, devices, kg) -> DRAGON:
    """Build a DRAGON model, load checkpoint weights into it and switch it to eval mode.

    Args:
        args: Run configuration; ``load_model_path`` must be set and ``encoder`` must be a
            key of ``MODEL_NAME_TO_CLASS``.
        devices: Pair ``(device0, device1)``; the model is moved to ``devices[1]`` and
            ``model.lmgnn.concept_emb`` to ``devices[0]``.
        kg: Knowledge-graph name forwarded to ``construct_model``.

    Returns:
        The loaded model in eval mode.

    Raises:
        AssertionError: If ``args.load_model_path`` is ``None``.
        ValueError: If no tokenizer class is known for the encoder's model type.
    """
    assert args.load_model_path is not None
    load_model_path = args.load_model_path
    print("loading from checkpoint: {}".format(load_model_path))
    # NOTE(security): torch.load unpickles the file; only load checkpoints from a trusted source.
    checkpoint = torch.load(load_model_path, map_location='cpu')

    # The configuration stored in the checkpoint is not imported into `args`
    # (the utils.import_config call was disabled), so `args` is used exactly as passed in.

    model = construct_model(args, kg)
    # NOTE(review): the environment value is a string, so any non-empty value (even '0')
    # is truthy here. Left unchanged because the attribute name on model.lmgnn must match
    # whatever the modeling code decided -- confirm against modeling_dragon.
    INHERIT_BERT = os.environ.get('INHERIT_BERT', 0)
    bert_or_roberta = model.lmgnn.bert if INHERIT_BERT else model.lmgnn.roberta

    tokenizer_classes = {'bert': BertTokenizer, 'xlnet': XLNetTokenizer, 'roberta': RobertaTokenizer}
    try:
        # AlbertTokenizer is only imported when the first (legacy) transformers import succeeds.
        tokenizer_classes['albert'] = AlbertTokenizer
    except NameError:
        pass
    model_type = MODEL_NAME_TO_CLASS[args.encoder]
    tokenizer_class = tokenizer_classes.get(model_type)
    if tokenizer_class is None:
        raise ValueError("No tokenizer class available for model type {!r} (encoder {!r})".format(
            model_type, args.encoder))
    tokenizer = tokenizer_class.from_pretrained(args.encoder)

    # Match the embedding matrix to the tokenizer vocabulary before loading the weights.
    bert_or_roberta.resize_token_embeddings(len(tokenizer))

    # strict=False tolerates key mismatches; report them instead of ignoring them silently.
    load_result = model.load_state_dict(checkpoint["model"], strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        logger.warning("Checkpoint loaded with strict=False: %d missing key(s), %d unexpected key(s)",
                       len(load_result.missing_keys), len(load_result.unexpected_keys))

    model.to(devices[1])
    model.lmgnn.concept_emb.to(devices[0])
    model.eval()

    print('inhouse?', args.inhouse)

    print('args.train_statements', args.train_statements)
    print('args.dev_statements', args.dev_statements)
    print('args.test_statements', args.test_statements)
    print('args.train_adj', args.train_adj)
    print('args.dev_adj', args.dev_adj)
    print('args.test_adj', args.test_adj)

    return model


def get_devices(args):
    """Choose the two torch devices used for the data and the model.

    Also stores the world size on ``args.world_size`` (1 unless ``args.local_rank != -1``)
    and prints a one-line summary to stderr.

    Args:
        args: Needs ``local_rank`` and ``cuda``; ``world_size`` is written.

    Returns:
        ``(device0, device1)``.
    """
    is_distributed = args.local_rank != -1

    if is_distributed and args.cuda:
        # Initialise the distributed backend, which takes care of synchronizing nodes/GPUs.
        torch.cuda.set_device(args.local_rank)
        device0 = torch.device("cuda", args.local_rank)
        device1 = device0
        torch.distributed.init_process_group(backend="nccl")
    else:
        gpu_count = torch.cuda.device_count()
        if gpu_count >= 2 and args.cuda:
            # Two or more GPUs: use the first two.
            device0 = torch.device("cuda:0")
            device1 = torch.device("cuda:1")
            print("device0: {}, device1: {}".format(device0, device1))
        elif gpu_count == 1 and args.cuda:
            device0 = device1 = torch.device("cuda:0")
        else:
            device0 = device1 = torch.device("cpu")

    # NOTE(review): with local_rank != -1 and cuda disabled, no process group has been
    # initialised in this function before get_world_size() is called.
    world_size = torch.distributed.get_world_size() if is_distributed else 1
    args.world_size = world_size
    summary = "Process rank: %s, device: %s, distributed training: %s, world_size: %s" % (
        args.local_rank, device0, bool(is_distributed), world_size)
    print(summary, file=sys.stderr)

    return device0, device1


# Root logging configuration: only WARNING and above are emitted; each record shows the
# timestamp, level, logger name, function name and line number before the message.
logging.basicConfig(format='%(asctime)s,%(msecs)d %(levelname)-8s [%(name)s:%(funcName)s():%(lineno)d] %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.WARNING)

  from .autonotebook import tqdm as notebook_tqdm


PreTrainedModelClass <class 'transformers.models.bert.modeling_bert.BertPreTrainedModel'>
ModelClass <class 'transformers.models.bert.modeling_bert.BertModel'>


## Getting model

In [None]:
# Run configuration, written out as a literal Namespace instead of being parsed from the
# command line. Points at the MedQA data files, the UMLS entity embeddings and the
# checkpoint 'models/medqa_model.pt'; local_rank=-1 means no distributed run.
args = argparse.Namespace(att_head_num=2, batch_size=1, cuda=True, cxt_node_connects_all=False, data_dir='data',
                          data_loader_one_process_at_a_time=False, dataset='medqa', debug=True, decoder_lr=0.0001,
                          dev_adj='data/medqa/graph/dev.graph.adj.pk',
                          dev_statements='data/medqa/statement/dev.statement.jsonl', dropoutf=0.2, dropoutg=0.2,
                          dropouti=0.2, dump_graph_cache=True, encoder='michiyasunaga/BioLinkBERT-large',
                          encoder_layer=-1, encoder_load_path='', encoder_lr=2e-05, end_task=1.0,
                          ent_emb_paths=['umls/ent_emb_blbertL.npy'], eval_batch_size=2, eval_interval=5,
                          fc_dim=200, fc_layer_num=0, fp16=True, freeze_ent_emb=True, gnn_dim=200, ie_dim=400,
                          ie_layer_num=1, info_exchange=True, inhouse=False,
                          inhouse_train_qids='data/medqa/inhouse_split_qids.txt', init_range=0.02, k=5, kg='umls',
                          kg_only_use_qa_nodes=False, kg_vocab_path='umls/concepts.txt', link_decoder='DistMult',
                          link_drop_max_count=100, link_drop_probability=0.2,
                          link_drop_probability_in_which_keep=0.2, link_gamma=12,
                          link_negative_adversarial_sampling=True, link_negative_adversarial_sampling_temperature=1,
                          link_negative_sample_size=64, link_normalize_headtail=0, link_proj_headtail=False,
                          link_regularizer_weight=0.01, link_task=0.0, load_graph_cache=True,
                          load_model_path='models/medqa_model.pt', local_rank=-1, log_interval=1,
                          loss='cross_entropy', lr_schedule='warmup_linear', max_epochs_before_stop=100,
                          max_grad_norm=1.0, max_node_num=200, max_num_relation=-1, max_seq_len=512,
                          mini_batch_size=1, mlm_probability=0.15, mlm_task=0.0, mode='eval', n_epochs=30,
                          n_train=-1, no_node_score=True, optim='radam', random_ent_emb=False, redef_epoch_steps=-1,
                          refreeze_epoch=10000, residual_ie=2, resume_checkpoint='None', resume_id='None',
                          run_name='run1', save_dir='./saved_models/', save_model=0.0, scaled_distmult=False,
                          seed=22, sep_ie_layers=False, span_mask=False, subsample=1.0,
                          test_adj='data/medqa/graph/test.graph.adj.pk',
                          test_statements='data/medqa/statement/test.statement.jsonl',
                          train_adj='data/medqa/graph/train.graph.adj.pk',
                          train_statements='data/medqa/statement/train.statement.jsonl', unfreeze_epoch=0,
                          upcast=True, use_codalab=0, use_wandb=False, warmup_steps=500.0, weight_decay=0.01,
                          world_size=1)

# Pick the (device0, device1) pair; this call also writes args.world_size.
devices = get_devices(args)

# wandb mode: disabled unless use_wandb is set; offline while debugging, online otherwise.
if not args.use_wandb:
    wandb_mode = "disabled"
elif args.debug:
    wandb_mode = "offline"
else:
    wandb_mode = "online"

# We can optionally resume training from a checkpoint. If doing so, also set the `resume_id` so that you resume your previous wandb run instead of creating a new one.
resume = args.resume_checkpoint not in [None, "None"]

args.hf_version = transformers.__version__

# Only the non-distributed process (-1) or rank 0 initialises wandb and prints diagnostics.
if args.local_rank in [-1, 0]:
    # Reuse the given run id when resuming; otherwise generate a fresh one.
    wandb_id = args.resume_id if resume and (args.resume_id not in [None, "None"]) else wandb.util.generate_id()
    args.wandb_id = wandb_id
    wandb.init(project="DRAGON", config=args, name=args.run_name, resume="allow", id=wandb_id,
               settings=wandb.Settings(start_method="fork"), mode=wandb_mode)
    print(socket.gethostname())
    print("pid:", os.getpid())
    print("conda env:", os.environ.get('CONDA_DEFAULT_ENV'))
    # $STY and $CUDA_VISIBLE_DEVICES are expanded by the shell that runs `echo`.
    print("screen: %s" % subprocess.check_output('echo $STY', shell=True).decode('utf'))
    print("gpu: %s" % subprocess.check_output('echo $CUDA_VISIBLE_DEVICES', shell=True).decode('utf'))
    utils.print_cuda_info()
    print("wandb id: ", wandb_id)
    wandb.run.log_code('.')

# The dataset name overrides the configured knowledge graph for the listed datasets.
kg = args.kg
if args.dataset == "medqa_usmle":
    kg = "ddb"
elif args.dataset in ["medqa", "pubmedqa", "bioasq"]:
    kg = "umls"
print("KG used:", kg)
print("KG used:", kg, file=sys.stderr)
# Build the model and load the checkpoint; `model`, `args`, `devices` and `kg` are read by
# the inference cells through get_pred().
model = get_model(args, devices, kg)

## Inference

In [4]:
# Example 1: a clinical multiple-choice question with four answer options.
question = "A 23-year-old pregnant woman at 22 weeks gestation presents with burning upon urination. She states it started 1 day ago and has been worsening despite drinking more water and taking cranberry extract. She otherwise feels well and is followed by a doctor for her pregnancy. Her temperature is 97.7°F (36.5°C), blood pressure is 122/77 mmHg, pulse is 80/min, respirations are 19/min, and oxygen saturation is 98% on room air. Physical exam is notable for an absence of costovertebral angle tenderness and a gravid uterus. Which of the following is the best treatment for this patient?"
answers = ["Ampicillin", "Ceftriaxone", "Doxycycline", "Nitrofurantoin"]
# get_pred returns (formatted line, predicted index); the 0-based index is printed as a
# 1-based option number.
preds = get_pred(question, answers, model)
print(f"Answer is: {preds[1].item() + 1}")

train_statement_path data/medqa/statement/train.statement.jsonl


Converting examples to features: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 75.01it/s]

num_choice 4
Loading sparse adj data...
Loading cache data/medqa/graph/train.graph.adj.pk-nodenum200.loaded_cache



local_rank -1 len(example_ids) 1
Loading cache data/medqa/graph/train.graph.adj.pk-nodenum200.loaded_cache
Loaded cache data/medqa/graph/train.graph.adj.pk-nodenum200.loaded_cache
local_rank -1 len(edge_index) 10178
local_rank -1 len(train_indexes) 32 train_indexes[:10] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


| ori_adj_len: mu 297.77 sigma 264.97 | adj_len: 147.16 | prune_rate： 0.53 | qc_num: 28.09 | ac_num: 1.54 |
local_rank -1 len(train_indexes) 32 train_indexes[:10] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


Dev/Test batch:   0%|                                                                                                                                           | 0/32 [00:15<?, ?it/s]


Answer is: 4


In [5]:
# Example 2: a free-form everyday question with four full-sentence answer options.
question = "I feel tired in the morning after I wake up. I sleep 6 hours per day. How can I overcome this issue?"
answers = [
    "You should sleep 1 hour more per day.",
    "You should sleep 2 hours more per day.",
    "You should sleep 1 hour less per day.",
    "Your sleep duration is fine. Try to go to bed earlier."]
# get_pred returns (formatted line, predicted index); the 0-based index is printed as a
# 1-based option number.
preds = get_pred(question, answers, model)
print(f"Answer is: {preds[1].item() + 1}")


train_statement_path data/medqa/statement/train.statement.jsonl


Converting examples to features: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 322.79it/s]

num_choice 4
Loading sparse adj data...
Loading cache data/medqa/graph/train.graph.adj.pk-nodenum200.loaded_cache



local_rank -1 len(example_ids) 1
Loading cache data/medqa/graph/train.graph.adj.pk-nodenum200.loaded_cache
Loaded cache data/medqa/graph/train.graph.adj.pk-nodenum200.loaded_cache
local_rank -1 len(edge_index) 10178
local_rank -1 len(train_indexes) 32 train_indexes[:10] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


| ori_adj_len: mu 297.77 sigma 264.97 | adj_len: 147.16 | prune_rate： 0.53 | qc_num: 28.09 | ac_num: 1.54 |
local_rank -1 len(train_indexes) 32 train_indexes[:10] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


Dev/Test batch:   0%|                                                                                                                                           | 0/32 [00:15<?, ?it/s]


Answer is: 2
