In [17]:
import random

try:
    # First try the class-based learning-rate schedulers.
    from transformers import (ConstantLRSchedule, WarmupLinearSchedule, WarmupConstantSchedule)
except ImportError:
    # Fall back to the factory-function schedulers when the classes are not
    # available. Only ImportError is caught so that Ctrl-C and unrelated
    # failures are not silently swallowed.
    from transformers import get_constant_schedule, get_constant_schedule_with_warmup,  get_linear_schedule_with_warmup

from modeling.modeling_qagnn import *
from utils.optimization_utils import OPTIMIZER_CLASSES
from utils.parser_utils import *
import utils
from collections import defaultdict, OrderedDict
import numpy as np

import socket, os, subprocess, datetime

## LSTM Encoder

In [18]:
import modeling.modeling_encoder as enc

# Smoke test: build a small LSTM text encoder and check the shapes it returns.
lstm_options = dict(vocab_size=100, emb_size=100, hidden_size=200, num_layers=4)
encoder = enc.TextEncoder('lstm', **lstm_options)

# Random token ids of shape (30, 70) and one random length in [1, 70) per row.
input_ids = torch.randint(0, 100, (30, 70))
lenghts = torch.randint(1, 70, (30,))
outputs = encoder(input_ids, lenghts)

lstm_first_out, lstm_per_layer_out = outputs[0], outputs[1]
assert lstm_first_out.size() == (30, 200)
# num_layers entries plus one extra; entry 0 has emb_size width, the rest hidden_size.
assert len(lstm_per_layer_out) == 4 + 1
for lstm_layer_idx, lstm_layer_out in enumerate(lstm_per_layer_out):
    lstm_expected_width = 100 if lstm_layer_idx == 0 else 200
    assert lstm_layer_out.size() == (30, 70, lstm_expected_width)
print('all tests are passed')
input_ids[0]

all tests are passed


tensor([44,  1, 16, 62, 78,  9, 51, 26,  7, 71, 90, 57, 65, 52, 48, 78, 89, 58,
        87, 58, 26, 90, 35, 29,  7, 46, 89, 82, 45, 34, 39, 85,  4, 53, 79, 39,
        82, 64, 47, 70, 66, 76, 74, 57, 56, 36, 31, 88, 74, 75, 76, 14, 20, 88,
        39, 57, 58, 95, 16,  0, 12, 62, 26, 90, 61, 27, 48, 77, 33, 29])

## BERT Encoder

In [19]:
from modeling.modeling_encoder import TextEncoder, MODEL_NAME_TO_CLASS, MODEL_CLASS_TO_NAME
from utils.data_utils import *
from utils.layers import *
import torch.nn.functional as F

# Look up the entry registered under the 'bert' key of MODEL_CLASS_TO_NAME.
# Despite its name, `model_type` ends up holding a list of checkpoint names
# (see the output of the next cell).
model_name = 'bert'
model_type = MODEL_CLASS_TO_NAME[model_name]

In [20]:
# Display the checkpoint names looked up for 'bert'.
model_type

['bert-base-uncased',
 'bert-large-uncased',
 'bert-base-cased',
 'bert-large-cased',
 'bert-base-multilingual-uncased',
 'bert-base-multilingual-cased',
 'bert-base-chinese',
 'bert-base-german-cased',
 'bert-large-uncased-whole-word-masking',
 'bert-large-cased-whole-word-masking',
 'bert-large-uncased-whole-word-masking-finetuned-squad',
 'bert-large-cased-whole-word-masking-finetuned-squad',
 'bert-base-cased-finetuned-mrpc',
 'bert-base-german-dbmdz-cased',
 'bert-base-german-dbmdz-uncased',
 'cl-tohoku/bert-base-japanese',
 'cl-tohoku/bert-base-japanese-whole-word-masking',
 'cl-tohoku/bert-base-japanese-char',
 'cl-tohoku/bert-base-japanese-char-whole-word-masking',
 'TurkuNLP/bert-base-finnish-cased-v1',
 'TurkuNLP/bert-base-finnish-uncased-v1',
 'wietsedv/bert-base-dutch-cased']

In [21]:
# Instantiate a TextEncoder for one concrete checkpoint name.
# NOTE: this rebinds both `model_name` and `encoder` from the earlier cells.
model_name = 'bert-base-uncased'
encoder_config={}  # no extra encoder options are passed
encoder = TextEncoder(model_name, **encoder_config)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [22]:
# Trailing semicolon: evaluate the name without rendering the module's repr.
encoder;

### MedQA params

In [23]:
# Settings for evaluating the MedQA-USMLE checkpoint.
dataset = "medqa_usmle"
model = 'cambridgeltl/SapBERT-from-PubMedBERT-fulltext'
ent_emb = 'ddb'
saved_models = "saved_models"

# Register all evaluation defaults in a single call (same keys, same order).
parser = get_parser()
parser.set_defaults(
    mode="eval_detail",
    encoder=model,
    ent_emb=ent_emb,
    ent_emb_paths=[utils.parser_utils.EMB_PATHS['ddb']],
    load_model_path="saved_models/medqa_usmle_model_hf3.4.0.pt",
)

In [24]:
# Show the embedding path registered under the 'ddb' key.
utils.parser_utils.EMB_PATHS['ddb']

'data/ddb/ent_emb.npy'

#### Set Params

In [25]:
# Dataset-dependent defaults, registered in one call (same keys, same order).
# NOTE(review): the train_* entries point at the *dev* split files -- presumably
# intentional for an evaluation-only notebook; confirm before reusing for training.
parser.set_defaults(
    dataset=dataset,
    save_model=True,
    save_dir=saved_models,
    train_adj="data/{}/graph/dev.graph.adj.pk".format(dataset),
    dev_adj="data/{}/graph/dev.graph.adj.pk".format(dataset),
    test_adj="data/{}/graph/test.graph.adj.pk".format(dataset),
    train_statements="data/{}/statement/dev.statement.jsonl".format(dataset),
    dev_statements="data/{}/statement/dev.statement.jsonl".format(dataset),
    test_statements="data/{}/statement/test.statement.jsonl".format(dataset),
)

# Extra options this notebook adds to the base parser.
parser.add_argument('-ebs', '--eval_batch_size', default=2, type=int)
parser.add_argument('--subsample', default=1.0, type=float)
parser.add_argument('--use_cache', default=True, type=bool_flag, nargs='?', const=True, help='use cached data to accelerate data loading')

_StoreAction(option_strings=['--use_cache'], dest='use_cache', nargs='?', const=True, default=True, type=<function bool_flag at 0x7fa3b94b1ea0>, choices=None, required=False, help='use cached data to accelerate data loading', metavar=None)

In [26]:
# Parse with an explicit empty argv: inside a notebook sys.argv belongs to the
# Jupyter kernel (e.g. "-f <connection file>"), so the namespace should be built
# purely from the defaults configured above. The leftover list gets a real name
# because assigning to `_` would clobber IPython's "last output" variable.
args, unparsed_argv = parser.parse_known_args(args=[])
args

Namespace(ent_emb='ddb', dataset='medqa_usmle', inhouse=True, inhouse_train_qids='data/csqa/inhouse_split_qids.txt', train_statements='data/medqa_usmle/statement/dev.statement.jsonl', dev_statements='data/medqa_usmle/statement/dev.statement.jsonl', test_statements='data/medqa_usmle/statement/test.statement.jsonl', max_seq_len=100, encoder='cambridgeltl/SapBERT-from-PubMedBERT-fulltext', encoder_layer=-1, encoder_lr=2e-05, loss='cross_entropy', optim='radam', lr_schedule='fixed', batch_size=32, warmup_steps=150, max_grad_norm=1.0, weight_decay=0.01, n_epochs=100, max_epochs_before_stop=10, log_interval=10, cuda=True, seed=0, debug=False, eval_batch_size=2, subsample=1.0, use_cache=True, ent_emb_paths=['data/ddb/ent_emb.npy'], mode='eval_detail', load_model_path='saved_models/medqa_usmle_model_hf3.4.0.pt', save_model=True, save_dir='saved_models', train_adj='data/medqa_usmle/graph/dev.graph.adj.pk', dev_adj='data/medqa_usmle/graph/dev.graph.adj.pk', test_adj='data/medqa_usmle/graph/test.

In [27]:
# Confirm the parsed namespace carries the embedding paths set via set_defaults.
args.ent_emb_paths

['data/ddb/ent_emb.npy']

### Evaluate QAGNN
The cells below are adapted mainly from the evaluation code in `qagnn.py`.

In [28]:
# A checkpoint path is mandatory for evaluation.
assert args.load_model_path is not None

# Load every entity-embedding matrix and join them along axis 1.
ent_emb_arrays = []
for ent_emb_path in args.ent_emb_paths:
    ent_emb_arrays.append(np.load(ent_emb_path))
cp_emb = torch.tensor(np.concatenate(ent_emb_arrays, axis=1), dtype=torch.float)

# First dimension: number of concepts; second: embedding width.
concept_num = cp_emb.size(0)
concept_dim = cp_emb.size(1)
print('| num_concepts: {} |'.format(concept_num))

| num_concepts: 9958 |


In [29]:
# Embedding file(s) that were just loaded into cp_emb.
args.ent_emb_paths

['data/ddb/ent_emb.npy']

In [30]:
# Second dimension of cp_emb, i.e. the width of the loaded concept embeddings.
concept_dim

768

In [31]:
model_path = args.load_model_path

# The checkpoint file holds a pair: the model weights and the argument namespace
# that was stored with them. It is mapped to CPU first; devices are chosen later.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
model_state_dict, old_args = checkpoint

# Architecture hyper-parameters come from the checkpoint's own arguments, while the
# concept table size/width come from the embeddings loaded in this notebook.
qagnn_options = dict(
    k=old_args.k,
    n_ntype=4,
    n_etype=old_args.num_relation,
    n_concept=concept_num,
    concept_dim=old_args.gnn_dim,
    concept_in_dim=concept_dim,
    n_attention_head=old_args.att_head_num,
    fc_dim=old_args.fc_dim,
    n_fc_layer=old_args.fc_layer_num,
    p_emb=old_args.dropouti,
    p_gnn=old_args.dropoutg,
    p_fc=old_args.dropoutf,
    pretrained_concept_emb=cp_emb,
    freeze_ent_emb=old_args.freeze_ent_emb,
    init_range=old_args.init_range,
    encoder_config={},
)
model = LM_QAGNN(old_args, old_args.encoder, **qagnn_options)
model.load_state_dict(model_state_dict)

FileNotFoundError: [Errno 2] No such file or directory: 'saved_models/medqa_usmle_model_hf3.4.0.pt'

In [54]:
# Inspect the argument namespace that was stored alongside the checkpoint weights.
old_args

Namespace(att_head_num=2, batch_size=128, cuda=True, dataset='medqa_usmle', debug=False, decoder_lr=0.001, dev_adj='data/medqa_usmle/graph/dev.graph.adj.pk', dev_statements='data/medqa_usmle/statement/dev.statement.jsonl', dropoutf=0.2, dropoutg=0.2, dropouti=0.2, encoder='cambridgeltl/SapBERT-from-PubMedBERT-fulltext', encoder_layer=-1, encoder_lr=5e-05, ent_emb=['ddb'], ent_emb_paths=['data/ddb/ent_emb.npy'], eval_batch_size=2, fc_dim=200, fc_layer_num=0, fp16=True, freeze_ent_emb=True, gnn_dim=200, inhouse=False, inhouse_train_qids='data/medqa_usmle/inhouse_split_qids.txt', init_range=0.02, k=5, load_model_path=None, log_interval=10, loss='cross_entropy', lr_schedule='fixed', max_epochs_before_stop=10, max_grad_norm=1.0, max_node_num=200, max_seq_len=512, mini_batch_size=2, mode='train', n_epochs=15, num_relation=34, optim='radam', refreeze_epoch=10000, save_dir='saved_models/medqa_usmle/enc-sapbert__k5__gnndim200__bs128__seed1__20211106_120518', save_model=True, seed=1, simple=Fals

In [55]:
# Number of GPUs to use; counts as zero when CUDA is switched off in args.
usable_gpus = torch.cuda.device_count()
if not args.cuda:
    usable_gpus = 0

# Two or more GPUs: one device each. One GPU: share it. Otherwise: CPU for both.
if usable_gpus >= 2:
    device0, device1 = torch.device("cuda:0"), torch.device("cuda:1")
elif usable_gpus == 1:
    device0 = device1 = torch.device("cuda:0")
else:
    device0 = device1 = torch.device("cpu")

In [56]:
# Put the encoder and decoder on their respective devices, then enable eval mode.
for model_part, part_device in ((model.encoder, device0), (model.decoder, device1)):
    model_part.to(part_device)
model.eval()

LM_QAGNN(
  (encoder): TextEncoder(
    (module): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                

In [57]:
# Merge the statement dictionaries of the train/dev/test files into one lookup.
statement_dic = {}
for split_name in ('train', 'dev', 'test'):
    statement_path = getattr(args, split_name + '_statements')
    statement_dic.update(load_statement_dict(statement_path))

# True when 'lm' is among the checkpoint's ent_emb entries.
use_contextualized = 'lm' in old_args.ent_emb

print('inhouse?', args.inhouse)

# Echo the resolved data paths: statement files first, then adjacency files.
for path_kind in ('statements', 'adj'):
    for split_name in ('train', 'dev', 'test'):
        attr_name = '{}_{}'.format(split_name, path_kind)
        print('args.' + attr_name, getattr(args, attr_name))


inhouse? True
args.train_statements data/medqa_usmle/statement/dev.statement.jsonl
args.dev_statements data/medqa_usmle/statement/dev.statement.jsonl
args.test_statements data/medqa_usmle/statement/test.statement.jsonl
args.train_adj data/medqa_usmle/graph/dev.graph.adj.pk
args.dev_adj data/medqa_usmle/graph/dev.graph.adj.pk
args.test_adj data/medqa_usmle/graph/test.graph.adj.pk


In [58]:
#from qagnn/qagnn.py
def evaluate_accuracy(eval_set, model):
    """Return the fraction of samples in ``eval_set`` that ``model`` classifies correctly.

    Args:
        eval_set: iterable of batches; each batch unpacks as
            ``(qids, labels, *model_inputs)`` where ``labels`` is a tensor of
            gold choice indices.
        model: callable invoked as ``model(*model_inputs)`` that returns a pair
            whose first element is a logits tensor (argmax is taken over dim 1).
            It is switched to eval mode and left that way.

    Returns:
        Accuracy as a float in [0, 1]; ``0.0`` when ``eval_set`` yields no
        samples (instead of raising ``ZeroDivisionError``).
    """
    n_samples, n_correct = 0, 0
    model.eval()
    # Gradients are not needed for scoring.
    with torch.no_grad():
        for qids, labels, *input_data in tqdm(eval_set):
            logits, _ = model(*input_data)
            n_correct += (logits.argmax(1) == labels).sum().item()
            n_samples += labels.size(0)
    if n_samples == 0:
        # Nothing was evaluated; mirror the notebook's "no test set" fallback of 0.0.
        return 0.0
    return n_correct / n_samples

In [59]:
# Statement/adjacency file pairs in the positional order the loader receives them.
split_sources = (
    args.train_statements, args.train_adj,
    args.dev_statements, args.dev_adj,
    args.test_statements, args.test_adj,
)

# Tokenisation-related settings come from the checkpoint's arguments (old_args);
# batching and caching settings come from this notebook's arguments (args).
loader_options = dict(
    batch_size=args.batch_size,
    eval_batch_size=args.eval_batch_size,
    device=(device0, device1),
    model_name=old_args.encoder,
    max_node_num=old_args.max_node_num,
    max_seq_length=old_args.max_seq_len,
    is_inhouse=args.inhouse,
    inhouse_train_qids_path=args.inhouse_train_qids,
    subsample=args.subsample,
    use_cache=args.use_cache,
)
dataset = LM_QAGNN_DataLoader(args, *split_sources, **loader_options)

train_statement_path data/medqa_usmle/statement/dev.statement.jsonl


100%|██████████| 1272/1272 [00:12<00:00, 100.09it/s]
100%|██████████| 1272/1272 [00:12<00:00, 101.47it/s]


num_choice 4
| ori_adj_len: mu 25.20 sigma 28.18 | adj_len: 26.20 | prune_rate： 0.00 | qc_num: 4.34 | ac_num: 1.07 |
| ori_adj_len: mu 25.20 sigma 28.18 | adj_len: 26.20 | prune_rate： 0.00 | qc_num: 4.34 | ac_num: 1.07 |


100%|██████████| 1273/1273 [00:12<00:00, 98.73it/s] 


| ori_adj_len: mu 26.09 sigma 28.17 | adj_len: 27.06 | prune_rate： 0.00 | qc_num: 4.61 | ac_num: 1.06 |


#### Batch Mode Evaluation

In [60]:
# Test predictions are written to disk only when save_model is set.
save_test_preds = args.save_model

# Score the dev split and report accuracy with four decimals.
dev_batches = dataset.dev()
dev_acc = evaluate_accuracy(dev_batches, model)
print('dev_acc {:7.4f}'.format(dev_acc))

100%|██████████| 636/636 [02:52<00:00,  3.69it/s]

dev_acc  0.3789





In [61]:
# Make sure layers such as dropout are in eval mode for the prediction branch too
# (evaluate_accuracy() does this itself; the branch below previously relied on an
# earlier cell having been run).
model.eval()

if not save_test_preds:
    test_acc = evaluate_accuracy(dataset.test(), model) if args.test_statements else 0.0
else:
    eval_set = dataset.test()
    total_acc = []  # one 0/1 entry per evaluated question
    dt = datetime.datetime.today().strftime('%Y%m%d%H%M%S')
    # Create the output directory up front so open() cannot fail on a fresh checkout.
    if args.save_dir:
        os.makedirs(args.save_dir, exist_ok=True)
    preds_path = os.path.join(args.save_dir, 'test_preds_{}.csv'.format(dt))
    with open(preds_path, 'w') as f_preds, torch.no_grad():
        for qids, labels, *input_data in tqdm(eval_set):
            # detail=True returns extra graph outputs; only the logits are needed here.
            logits = model(*input_data, detail=True)[0]
            pred_choices = logits.argmax(1).tolist()  # one predicted index per question
            gold_choices = labels.tolist()
            for qid, gold, pred in zip(qids, gold_choices, pred_choices):
                # One CSV row per question: "<qid>,<letter>" with 0 -> 'A', 1 -> 'B', ...
                print ('{},{}'.format(qid, chr(ord('A') + pred)), file=f_preds)
                # Flush per row so partial results survive an interrupted run.
                f_preds.flush()
                total_acc.append(int(pred == gold))
    # Guard against an empty test set instead of dividing by zero.
    test_acc = sum(total_acc) / len(total_acc) if total_acc else 0.0

# Report the result for both branches (previously only the prediction branch printed).
print('-' * 71)
print('test_acc {:7.4f}'.format(test_acc))
print('-' * 71)


 34%|███▍      | 216/636 [00:57<01:52,  3.73it/s]


KeyboardInterrupt: 

#### Sample Mode Evaluation

In [62]:
eval_set = dataset.test()
# Take only the first batch lazily instead of materialising the whole test set.
# (To inspect random batches instead: random.sample(list(eval_set), 2).)
subsample = [batch for _, batch in zip(range(1), eval_set)]
total_acc = []
count = 0
dt = datetime.datetime.today().strftime('%Y%m%d%H%M%S')
# Create the output directory up front so open() cannot fail on a fresh checkout.
if args.save_dir:
    os.makedirs(args.save_dir, exist_ok=True)
preds_path = os.path.join(args.save_dir, 'test_preds_{}.csv'.format(dt))
# Do not rely on an earlier cell having switched the model to eval mode.
model.eval()
with open(preds_path, 'w') as f_preds:
    with torch.no_grad():
        for qids, labels, *input_data in tqdm(subsample):
            count += 1
            logits, _, concept_ids, node_type_ids, edge_index, edge_type = model(*input_data, detail=True)
            # Show what went into / came out of the model for this batch.
            print(qids, labels)
            print(concept_ids)
            predictions = logits.argmax(1) #[bsize, ]
            preds_ranked = (-logits).argsort(1) #[bsize, n_choices]
            # The per-sample names (cids, ntype, edges, etype, ...) are kept on purpose:
            # they stay in the notebook namespace for inspection in later cells.
            for i, (qid, label, pred, _preds_ranked, cids, ntype, edges, etype) in enumerate(zip(qids, labels, predictions, preds_ranked, concept_ids, node_type_ids, edge_index, edge_type)):
                acc = int(pred.item()==label.item())
                # One CSV row per question: "<qid>,<letter>" with 0 -> 'A', 1 -> 'B', ...
                print ('{},{}'.format(qid, chr(ord('A') + pred.item())), file=f_preds)
                f_preds.flush()
                total_acc.append(acc)

100%|██████████| 1/1 [00:00<00:00,  3.86it/s]

['train-00000', 'train-00001'] tensor([2, 0], device='cuda:0')
tensor([[[   0,  553,  821,  ...,    1,    1,    1],
         [   0,  553,  821,  ...,    1,    1,    1],
         [   0,  553,  821,  ...,    1,    1,    1],
         [   0,  553,  821,  ...,    1,    1,    1]],

        [[   0,   57, 2089,  ...,    1,    1,    1],
         [   0,   57, 2089,  ...,    1,    1,    1],
         [   0,   57, 2089,  ...,    1,    1,    1],
         [   0,   57, 2089,  ...,    1,    1,    1]]], device='cuda:0')





In [35]:
# Number of batches kept for sample-mode evaluation.
len(subsample)

1

In [38]:
# Inspect the raw contents of the first sampled batch.
subsample[0]

(['train-00000', 'train-00001'],
 tensor([2, 0], device='cuda:0'),
 tensor([[[   2,   43, 2900,  ...,    0,    0,    0],
          [   2,   43, 2900,  ...,    0,    0,    0],
          [   2,   43, 2900,  ...,    0,    0,    0],
          [   2,   43, 2900,  ...,    0,    0,    0]],
 
         [[   2,   43,   25,  ...,    0,    0,    0],
          [   2,   43,   25,  ...,    0,    0,    0],
          [   2,   43,   25,  ...,    0,    0,    0],
          [   2,   43,   25,  ...,    0,    0,    0]]], device='cuda:0'),
 tensor([[[1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0]],
 
         [[1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0]]], device='cuda:0'),
 tensor([[[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]],
 
         [[0, 0, 0, 