# Initialization

## packages

In [1]:
import os
import jupyter_capture_output
working_path = os.getcwd()
os.chdir('../')

Jupyter Capture Output v0.0.11


In [2]:
import sys
import torch
import json
import random
import argparse
import collections
import torch.nn as nn
from uer.utils.vocab import Vocab
from uer.utils.constants import *
from uer.utils.tokenizer import * 
from uer.model_builder import build_model
from uer.utils.optimizers import  BertAdam
from uer.utils.config import load_hyperparam
from uer.utils.seed import set_seed
from uer.model_saver import save_model
from brain import KnowledgeGraph
from multiprocessing import Process, Pool
import numpy as np
import os

from transformers import AutoTokenizer
from inf_classifier import BertClassifier, add_argument_for_paser_of_BertClassifier

## set arguments

In [3]:
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
add_argument_for_paser_of_BertClassifier(parser)

In [4]:
args=parser.parse_args([
  "--pretrained_model_path", "./models/google_model.bin" ,
    "--config_path", "./models/google_config.json", 
    "--vocab_path", "./models/google_vocab.txt", 
    "--train_path", "./datasets/CheXpert/impression/train.csv", 
    "--dev_path", "./datasets/CheXpert/impression/validation.csv",
    "--test_path", "./datasets/CheXpert/impression/test.csv",
    "--epochs_num", "3", 
    "--batch_size", "16", 
   #  "--tokenizer_from_huggingface", "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
    "--tokenizer_from_huggingface", "",
    "--kg_name", "./brain/kgs/CheXpert_KG.spo",
    "--output_model_path", "./outputs/infCheXbert_half_integrated_2/infCheXbert_half_integrated.bin"])


args = load_hyperparam(args)
args.learning_rate = 1e-5
set_seed(args.seed)

str2tokenizer = {"char": CharTokenizer, "space": SpaceTokenizer, "bert": BertTokenizer}
args.tokenizer = str2tokenizer[args.tokenizer](args)

In [5]:
if '.bin' not in args.output_model_path:
    print('output models should be saved as .bin file')
else:
    path = args.output_model_path.split('/')
    name_of_modelscompany = path[-1].replace('.bin','')
    path = [x + '/' for x in path]
    path.remove(path[-1])
    Parent_directory = ''.join(path)
    if not os.path.exists(Parent_directory):
        os.makedirs(Parent_directory)

## initialise global variables 

In [6]:
columns = {} # to check column locations of labels and text 
label_columns = {} # to check column location of each label
label_names = []
labels_sets= [] # a list of store sets of label values, in my case, label values are 0,1,2
with open(args.train_path, mode="r", encoding="utf-8") as f:
    for line_id, line in enumerate(f):
        try:
            line = line.strip().split("\t")
            if line_id == 0:
                for i, column_name in enumerate(line):
                    columns[column_name] = i
                    if 'label' in column_name:
                        label_columns[column_name] = i
                        label_names.append( column_name)
                        labels_sets.append(set())
                continue
            # count label numbers for each label name
            # in our case, labels_nums is known: 3 for each label name. following is just for generalization
            for i,label_set in enumerate(labels_sets):
                label = int(line[label_columns[label_names[i]]])
                label_set.add(label)
        except:
            pass



labels_nums = [len(labels_set) for labels_set in labels_sets] 

# Load vocabulary.
vocab = Vocab()
print('for english sentences, huggingface tokenizer or nltk tokenizer is recommanded')
if args.tokenizer_from_huggingface: 
    print('huggingface tokenizer detected, use tokenizer {args.tokenizer_from_huggingface}')
    tok_en = AutoTokenizer.from_pretrained(args.tokenizer_from_huggingface)
vocab.load(args.vocab_path)
args.vocab = vocab
args.target = "bert"

device = torch.device('mps' if torch.backends.mps.is_available() 
                      else "cuda" if torch.cuda.is_available() 
                      else "cpu")



if torch.backends.mps.is_available(): torch.mps.empty_cache()
# Build bert model.
models = []
output_model_path = []

for counter,label_num in enumerate(labels_nums):
    args.labels_num = label_num
    output_model_path.append(args.output_model_path.replace('.bin','') + '_' + label_names[counter] +'.bin')
    print(f'labeler {label_names[counter]} would be saved under path {output_model_path[counter]}')
    model = build_model(args)
    # Load or initialize parameters.
    if args.pretrained_model_path is not None:
        # Initialize with pretrained model.
        model.load_state_dict(torch.load(args.pretrained_model_path), strict=False)  
    else:
        # Initialize with normal distribution.
        for n, p in list(model.named_parameters()):
            if 'gamma' not in n and 'beta' not in n:
                p.data.normal_(0, 0.02)

    model = BertClassifier(args, model)

    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(torch.cuda.device_count()))
        model = nn.DataParallel(model)

    model = model.to(device)
    
    models.append(model)

args.output_model_path = output_model_path


for english sentences, huggingface tokenizer or nltk tokenizer is recommanded
Vocabulary file line 344 has bad format token
Vocabulary Size:  21128
labeler label_Atelectasis would be saved under path ./outputs/infCheXbert_half_integrated_2/infCheXbert_half_integrated_label_Atelectasis.bin
[BertClassifier] use visible_matrix: True
labeler label_Cardiomegaly would be saved under path ./outputs/infCheXbert_half_integrated_2/infCheXbert_half_integrated_label_Cardiomegaly.bin
[BertClassifier] use visible_matrix: True
labeler label_Consolidation would be saved under path ./outputs/infCheXbert_half_integrated_2/infCheXbert_half_integrated_label_Consolidation.bin
[BertClassifier] use visible_matrix: True
labeler label_Edema would be saved under path ./outputs/infCheXbert_half_integrated_2/infCheXbert_half_integrated_label_Edema.bin
[BertClassifier] use visible_matrix: True
labeler label_Enlarged Cardiomediastinum would be saved under path ./outputs/infCheXbert_half_integrated_2/infCheXbert_hal

# Build knowledge graph.

In [7]:
if args.kg_name == 'none':
    spo_files = []
else:
    spo_files = [args.kg_name]
kg = KnowledgeGraph(spo_files=spo_files, tokenizer=args.tokenizer,predicate=True)

[KnowledgeGraph] Loading spo from ./brain/kgs/CheXpert_KG.spo


## define assist functions

In [8]:
def add_knowledge_worker(params):
    '''
    - input parameters are p_id, sentences, columns, label_columns, kg, vocab, args.seq_length
    - output is a dataset which is a list
        - structure of the dataset output:
            output[0]=token_ids, 
            output[1]=label, 
            output[2]=mask,
            output[3]=pos, 
            output[4]=vm 
        - in case of multiple classifier with multiple labels: output[1]=label is again a list
            where each entry has all label value for its label name
            for example: 
                output[1][0] == [1,0,2,3,4,...] for columns['label_Atelectasis]
                output[1][1] == [1,0,2,3,4,...] for columns['label_Cardiomegaly]
        

    '''

    p_id, sentences, columns, kg, vocab, args = params
    text_column = {}
    labels_columns = {}
    for k,v in columns.items():
        if 'label' in k:
            labels_columns.update({k:v})
        else:
            text_column.update({k:v})

    sentences_num = len(sentences)
    dataset = []

    labels_position = list(labels_columns.values())
    text_position = list(text_column.values())

    for line_id, line in enumerate(sentences):
        if line_id % 10000 == 0:
            print("Progress of process {}: {}/{}".format(p_id, line_id, sentences_num))
            sys.stdout.flush()
        line = line.strip().split('\t')
        try:
            if len(line) == 2:
                label = [int(line[columns["label"]])]
                text = CLS_TOKEN + line[columns["text_a"]]
   
                tokens, pos, vm, _ = kg.add_knowledge_with_vm([text], add_pad=True, max_length=args.seq_length)
                tokens = tokens[0]
                pos = pos[0]
                vm = vm[0].astype("bool")
                if args.tokenizer_from_huggingface:
                    token_ids = tok_en.convert_tokens_to_ids(tokens)
                else:
                    token_ids = [vocab.get(t) for t in tokens]
                    # token_ids = args.tokenizer.convert_tokens_to_ids(tokens)
                
                mask = [1 if t != PAD_TOKEN else 0 for t in tokens]

                dataset.append((token_ids, label, mask, pos, vm))
            
            elif (len(line) == 3) and ("text_b" in line):
                label = int(line[columns["label"]])
                text = CLS_TOKEN + line[columns["text_a"]] + SEP_TOKEN + line[columns["text_b"]] + SEP_TOKEN

                tokens, pos, vm, _ = kg.add_knowledge_with_vm([text], add_pad=True, max_length=args.seq_length)
                tokens = tokens[0]
                pos = pos[0]
                vm = vm[0].astype("bool")

                token_ids = [vocab.get(t) for t in tokens]
                mask = []
                seg_tag = 1
                for t in tokens:
                    if t == PAD_TOKEN:
                        mask.append(0)
                    else:
                        mask.append(seg_tag)
                    if t == SEP_TOKEN:
                        seg_tag += 1

                dataset.append((token_ids, label, mask, pos, vm))
            
            elif (len(line) == 4) and ('qid' in line):  # for dbqa
                qid=int(line[columns["qid"]])
                label = int(line[columns["label"]])
                text_a, text_b = line[columns["text_a"]], line[columns["text_b"]]
                text = CLS_TOKEN + text_a + SEP_TOKEN + text_b + SEP_TOKEN

                tokens, pos, vm, _ = kg.add_knowledge_with_vm([text], add_pad=True, max_length=args.seq_length)
                tokens = tokens[0]
                pos = pos[0]
                vm = vm[0].astype("bool")

                token_ids = [vocab.get(t) for t in tokens]
                mask = []
                seg_tag = 1
                for t in tokens:
                    if t == PAD_TOKEN:
                        mask.append(0)
                    else:
                        mask.append(seg_tag)
                    if t == SEP_TOKEN:
                        seg_tag += 1
                
                dataset.append((token_ids, label, mask, pos, vm, qid))

            # multiple classification with multiple labels

            elif len(labels_columns.keys()) >=2 : 

                labels = [line[x] for x in labels_position]
                text = [line[x] for x in text_position]
                text = CLS_TOKEN + ' ' +text[0]

                tokens, pos, vm, _ = kg.add_knowledge_with_vm([text], add_pad=True, max_length=args.seq_length)
                tokens = tokens[0]
                pos = pos[0]
                vm = vm[0].astype("bool")
                # token_ids = [vocab.get(t) for t in tokens]
                token_ids = args.tokenizer.convert_tokens_to_ids(tokens)
                # token_ids = tok_en.convert_tokens_to_ids(tokens)
                mask = [1 if t != PAD_TOKEN else 0 for t in tokens]

                dataset.append((token_ids, labels, mask, pos, vm))
                
            
        except Exception as e:
            print("Error line: ", line)
            print(e)
    return dataset


In [9]:
# batch loader.

'''
def batch_loader(batch_size, input_ids, label_ids, mask_ids, pos_ids, vms):
    instances_num = input_ids.size()[0]
    for i in range(instances_num // batch_size):
        input_ids_batch = input_ids[i*batch_size: (i+1)*batch_size, :]
        labels_ids_batch = [label_id[i*batch_size: (i+1)*batch_size] for label_id in label_ids]
        mask_ids_batch = mask_ids[i*batch_size: (i+1)*batch_size, :]
        pos_ids_batch = pos_ids[i*batch_size: (i+1)*batch_size, :]
        vms_batch = vms[i*batch_size: (i+1)*batch_size]
        yield input_ids_batch, labels_ids_batch, mask_ids_batch, pos_ids_batch, vms_batch
    if instances_num > instances_num // batch_size * batch_size:
        input_ids_batch = input_ids[instances_num//batch_size*batch_size:, :]
        labels_ids_batch = [label_id[instances_num//batch_size*batch_size:] for label_id in label_ids]
        mask_ids_batch = mask_ids[instances_num//batch_size*batch_size:, :]
        pos_ids_batch = pos_ids[instances_num//batch_size*batch_size:, :]
        vms_batch = vms[instances_num//batch_size*batch_size:]

        yield input_ids_batch, labels_ids_batch, mask_ids_batch, pos_ids_batch, vms_batch
'''

def multi_label_batch_loader(batch_size, input_ids, label_ids, mask_ids, pos_ids, vms):
    instances_num = input_ids.size()[0]
    for i in range(instances_num // batch_size):
        input_ids_batch = input_ids[i*batch_size: (i+1)*batch_size, :]
        labels_ids_batch = [label_id[i*batch_size: (i+1)*batch_size] for label_id in label_ids]
        mask_ids_batch = mask_ids[i*batch_size: (i+1)*batch_size, :]
        pos_ids_batch = pos_ids[i*batch_size: (i+1)*batch_size, :]
        vms_batch = vms[i*batch_size: (i+1)*batch_size]
        yield input_ids_batch, labels_ids_batch, mask_ids_batch, pos_ids_batch, vms_batch
    if instances_num > instances_num // batch_size * batch_size:
        input_ids_batch = input_ids[instances_num//batch_size*batch_size:, :]
        labels_ids_batch = [label_id[instances_num//batch_size*batch_size:] for label_id in label_ids]
        mask_ids_batch = mask_ids[instances_num//batch_size*batch_size:, :]
        pos_ids_batch = pos_ids[instances_num//batch_size*batch_size:, :]
        vms_batch = vms[instances_num//batch_size*batch_size:]
        yield input_ids_batch, labels_ids_batch, mask_ids_batch, pos_ids_batch, vms_batch

def batch_loader(batch_size, input_ids, label_ids, mask_ids, pos_ids, vms):
    instances_num = input_ids.size()[0]
    for i in range(instances_num // batch_size):
        input_ids_batch = input_ids[i*batch_size: (i+1)*batch_size, :]
        label_ids_batch = label_ids[i*batch_size: (i+1)*batch_size]
        mask_ids_batch = mask_ids[i*batch_size: (i+1)*batch_size, :]
        pos_ids_batch = pos_ids[i*batch_size: (i+1)*batch_size, :]
        vms_batch = vms[i*batch_size: (i+1)*batch_size]
        yield input_ids_batch, label_ids_batch, mask_ids_batch, pos_ids_batch, vms_batch
    if instances_num > instances_num // batch_size * batch_size:
        input_ids_batch = input_ids[instances_num//batch_size*batch_size:, :]
        label_ids_batch = label_ids[instances_num//batch_size*batch_size:]
        mask_ids_batch = mask_ids[instances_num//batch_size*batch_size:, :]
        pos_ids_batch = pos_ids[instances_num//batch_size*batch_size:, :]
        vms_batch = vms[instances_num//batch_size*batch_size:]
        yield input_ids_batch, label_ids_batch, mask_ids_batch, pos_ids_batch, vms_batch
# dataset loader
def read_dataset(path, workers_num=1):

    print("Loading sentences from {}".format(path))
    sentences = []
    with open(path, mode='r', encoding="utf-8") as f:
        for line_id, line in enumerate(f):
            if line_id == 0:
                continue
            sentences.append(line)
    sentence_num = len(sentences)

    print("There are {} sentence in total. We use {} processes to inject knowledge into sentences.".format(sentence_num, workers_num))
    if workers_num > 1:
        params = []
        sentence_per_block = int(sentence_num / workers_num) + 1
        for i in range(workers_num):
            params.append((i, sentences[i*sentence_per_block: (i+1)*sentence_per_block], columns, kg, vocab, args))
        pool = Pool(workers_num)
        res = pool.map(add_knowledge_worker, params)
        pool.close()
        pool.join()
        dataset = [sample for block in res for sample in block]
    else:
        params = (0, sentences, columns, kg, vocab, args)
        dataset = add_knowledge_worker(params)

    return dataset

# Evaluation function.
def evaluate(args, is_test, metrics='Acc',label_id = -1):
    counter = label_id
    if is_test:
        dataset = read_dataset(args.test_path, workers_num=args.workers_num)
    else:
        dataset = read_dataset(args.dev_path, workers_num=args.workers_num)
    labels_ids=[]
    input_ids = torch.LongTensor([sample[0] for sample in dataset])
    if label_id == -1:
        label_ids = torch.LongTensor([sample[1] for sample in dataset])
    for nr, label in enumerate(label_columns):
        labels_ids.append(torch.LongTensor([int(example[1][nr]) for example in dataset]))
    mask_ids = torch.LongTensor([sample[2] for sample in dataset])
    pos_ids = torch.LongTensor([example[3] for example in dataset])
    vms = [example[4] for example in dataset]
    label_ids=labels_ids[counter]
    batch_size = args.batch_size
    instances_num = input_ids.size()[0]
    if is_test:
        print("The number of evaluation instances: ", instances_num)

    correct = 0
    # Confusion matrix.

    confusions = [torch.zeros(num, num, dtype=torch.long) for num in labels_nums]
    confusion = confusions[counter]
    model = models[counter]
    label_name =label_names[counter]
    label_ids = labels_ids[counter]

    model.eval()
        
    if not args.mean_reciprocal_rank:
        for i, (input_ids_batch, label_ids_batch,  mask_ids_batch, pos_ids_batch, vms_batch) in enumerate(batch_loader(batch_size, input_ids, label_ids, mask_ids, pos_ids, vms)):

            # vms_batch = vms_batch.long()
            vms_batch = torch.LongTensor(np.array(vms_batch))

            input_ids_batch = input_ids_batch.to(device)
            label_ids_batch = label_ids_batch.to(device)
            mask_ids_batch = mask_ids_batch.to(device)
            pos_ids_batch = pos_ids_batch.to(device)
            vms_batch = vms_batch.to(device)

            with torch.inference_mode():
                try:
                    loss, logits = model(input_ids_batch, label_ids_batch, mask_ids_batch, pos_ids_batch, vms_batch)
                except:
                    print(input_ids_batch)
                    print(input_ids_batch.size())
                    print(vms_batch)
                    print(vms_batch.size())

            logits = nn.Softmax(dim=1)(logits)
            pred = torch.argmax(logits, dim=1).to(device)
            gold = label_ids_batch

            for j in range(pred.size()[0]):
                    confusion[pred[j], gold[j]] += 1
            correct += torch.sum(pred == gold).item()
    
        if is_test:
            print("Confusion matrix of {label_name}:")
            print(confusion)
            print("Report precision, recall, and f1:")
        
        for i in range(confusion.size()[0]):

            if (confusion[i,:].sum().item() == 0) and (confusion[:,i].sum().item()!= 0):
                print(f'model never predicts label value {i}')


            elif (confusion[:,i].sum().item()== 0) and (confusion[i,:].sum().item() != 0):
                print(f'dataset has no label value {i}')
            eps = 1e-9
            p = confusion[i,i].item()/(confusion[i,:].sum().item() + eps)
            r = confusion[i,i].item()/(confusion[:,i].sum().item() + eps)
            f1 = 2*p*r / (p + r + eps)

            if i == 1:
                label_1_f1 = f1
            print("labelsname: {}, Label_value {}: {:.3f}, {:.3f}, {:.3f}".format(label_name,i,p,r,f1))
        print("Acc. (Correct/Total): {:.4f} ({}/{}) ".format(correct/len(dataset), correct, len(dataset)))
        if metrics == 'Acc':
            return correct/len(dataset)
        elif metrics == 'f1':
            return label_1_f1
        else:
            return correct/len(dataset)
    else:
            for i, (input_ids_batch, label_ids_batch,  mask_ids_batch, pos_ids_batch, vms_batch) in enumerate(batch_loader(batch_size, input_ids, label_ids, mask_ids, pos_ids, vms)):

                vms_batch = torch.LongTensor(np.array(vms_batch))

                input_ids_batch = input_ids_batch.to(device)
                label_ids_batch = label_ids_batch.to(device)
                mask_ids_batch = mask_ids_batch.to(device)
                pos_ids_batch = pos_ids_batch.to(device)
                vms_batch = vms_batch.to(device)

                with torch.no_grad():
                    loss, logits = model(input_ids_batch, label_ids_batch, mask_ids_batch, pos_ids_batch, vms_batch)
                logits = nn.Softmax(dim=1)(logits)
                if i == 0:
                    logits_all=logits
                if i >= 1:
                    logits_all=torch.cat((logits_all,logits),0)
        
            order = -1
            gold = []
            for i in range(len(dataset)):
                qid = dataset[i][-1]
                label = dataset[i][1]
                if qid == order:
                    j += 1
                    if label == 1:
                        gold.append((qid,j))
                else:
                    order = qid
                    j = 0
                    if label == 1:
                        gold.append((qid,j))

            label_order = []
            order = -1
            for i in range(len(gold)):
                if gold[i][0] == order:
                    templist.append(gold[i][1])
                elif gold[i][0] != order:
                    order=gold[i][0]
                    if i > 0:
                        label_order.append(templist)
                    templist = []
                    templist.append(gold[i][1])
            label_order.append(templist)

            order = -1
            score_list = []
            for i in range(len(logits_all)):
                score = float(logits_all[i][1])
                qid=int(dataset[i][-1])
                if qid == order:
                    templist.append(score)
                else:
                    order = qid
                    if i > 0:
                        score_list.append(templist)
                    templist = []
                    templist.append(score)
            score_list.append(templist)

            rank = []
            pred = []
            print(len(score_list))
            print(len(label_order))
            for i in range(len(score_list)):
                if len(label_order[i])==1:
                    if label_order[i][0] < len(score_list[i]):
                        true_score = score_list[i][label_order[i][0]]
                        score_list[i].sort(reverse=True)
                        for j in range(len(score_list[i])):
                            if score_list[i][j] == true_score:
                                rank.append(1 / (j + 1))
                    else:
                        rank.append(0)

                else:
                    true_rank = len(score_list[i])
                    for k in range(len(label_order[i])):
                        if label_order[i][k] < len(score_list[i]):
                            true_score = score_list[i][label_order[i][k]]
                            temp = sorted(score_list[i],reverse=True)
                            for j in range(len(temp)):
                                if temp[j] == true_score:
                                    if j < true_rank:
                                        true_rank = j
                    if true_rank < len(score_list[i]):
                        rank.append(1 / (true_rank + 1))
                    else:
                        rank.append(0)
            MRR = sum(rank) / len(rank)
            print("MRR", MRR)
            return MRR


# Preprocess

## load datasets, integrate KG into datasets and convert to tensors

In [10]:

# Training phase.
print("load training dataset")
trainset = read_dataset(args.train_path, workers_num=args.workers_num)
print("Shuffling dataset")
random.shuffle(trainset)
instances_num = len(trainset)
batch_size = args.batch_size

print("Transfer data to tensor, which includes: ")
print("input_ids")
input_ids = torch.LongTensor([example[0] for example in trainset])
print("label_ids")
labels_ids = []

for counter in range(len(label_columns)):
    labels_ids.append(torch.LongTensor([int(example[1][counter]) for example in trainset]))

'''
for counter, label in enumerate(label_columns):
    # labels_ids.append(torch.LongTensor([int(example[1][label_columns[label]]) for example in trainset]))
    labels_ids.append(torch.LongTensor([int(example[1][counter]) for example in trainset]))
'''
print("mask_ids")
mask_ids = torch.LongTensor([example[2] for example in trainset])
print("pos_ids")
pos_ids = torch.LongTensor([example[3] for example in trainset])
print("vms")
vms = [example[4] for example in trainset]

load training dataset
Loading sentences from ./datasets/CheXpert/impression/train.csv
There are 99371 sentence in total. We use 1 processes to inject knowledge into sentences.
Progress of process 0: 0/99371
Progress of process 0: 10000/99371
Progress of process 0: 20000/99371
Progress of process 0: 30000/99371


KeyboardInterrupt: 

# train and evaluate

## modelscompany for 3 observations

In [11]:
models_3_obs = models[:3]

In [12]:
train_steps = int(instances_num * args.epochs_num / batch_size) + 1

print("Batch size: ", batch_size)
print("The number of training instances:", instances_num)



for counter, model in enumerate(models_3_obs):
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
    ]
    optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate, warmup=args.warmup, t_total=train_steps)


    total_loss = 0.
    result = 0.0
    best_result = 0.0
    label_name = label_names[counter]
    label_ids = labels_ids[counter]
    for epoch in range(1, args.epochs_num+1):

        model.train()


        for i, (input_ids_batch, label_ids_batch, mask_ids_batch, pos_ids_batch, vms_batch) in enumerate(batch_loader(batch_size, input_ids, label_ids, mask_ids, pos_ids, vms)):
            model.zero_grad()

            vms_batch = torch.LongTensor(np.array(vms_batch))

            input_ids_batch = input_ids_batch.to(device)
            label_ids_batch = label_ids_batch.to(device)
            mask_ids_batch = mask_ids_batch.to(device)
            pos_ids_batch = pos_ids_batch.to(device)
            vms_batch = vms_batch.to(device)

            loss, _ = model(input_ids_batch, label_ids_batch, mask_ids_batch, pos=pos_ids_batch, vm=vms_batch)
            if torch.cuda.device_count() > 1:
                loss = torch.mean(loss)
            total_loss += loss.item()
            if (i + 1) % args.report_steps == 0:
                print("labelsname: {}, Epoch id: {}, Training steps: {}, Avg loss: {:.3f}".format(label_name,epoch, i+1, (total_loss / args.report_steps)))
                sys.stdout.flush()
                total_loss = 0.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print("Start evaluation on test dataset.")         
        result = evaluate(args, False,label_id=counter)
        if result > best_result:
            best_result = result
            save_model(model, args.output_model_path[counter])
        else:
            continue

        print("Start evaluation on evaluation dataset.")
        evaluate(args, True,label_id=counter)

    # Evaluation phase.
    print("Final evaluation on the evaluation dataset.")

    if torch.cuda.device_count() > 1:
        model.module.load_state_dict(torch.load(args.output_model_path[counter]))
    else:
        model.load_state_dict(torch.load(args.output_model_path[counter]))
    evaluate(args, True,label_id=counter)

Batch size:  16
The number of training instances: 99371


	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/python_arg_parser.cpp:1485.)
  next_m.mul_(beta1).add_(1 - beta1, grad)


labelsname: label_Atelectasis, Epoch id: 1, Training steps: 100, Avg loss: 0.968
labelsname: label_Atelectasis, Epoch id: 1, Training steps: 200, Avg loss: 0.611
labelsname: label_Atelectasis, Epoch id: 1, Training steps: 300, Avg loss: 0.298
labelsname: label_Atelectasis, Epoch id: 1, Training steps: 400, Avg loss: 0.280
labelsname: label_Atelectasis, Epoch id: 1, Training steps: 500, Avg loss: 0.279
labelsname: label_Atelectasis, Epoch id: 1, Training steps: 600, Avg loss: 0.226
labelsname: label_Atelectasis, Epoch id: 1, Training steps: 700, Avg loss: 0.248
labelsname: label_Atelectasis, Epoch id: 1, Training steps: 800, Avg loss: 0.186
labelsname: label_Atelectasis, Epoch id: 1, Training steps: 900, Avg loss: 0.151
labelsname: label_Atelectasis, Epoch id: 1, Training steps: 1000, Avg loss: 0.163
labelsname: label_Atelectasis, Epoch id: 1, Training steps: 1100, Avg loss: 0.136
labelsname: label_Atelectasis, Epoch id: 1, Training steps: 1200, Avg loss: 0.182
labelsname: label_Atelect

## continue training the complete models company

In [13]:
'''
models_rest = models[3:]

train_steps = int(instances_num * args.epochs_num / batch_size) + 1

print("Batch size: ", batch_size)
print("The number of training instances:", instances_num)

param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'gamma', 'beta']
optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
]
optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate, warmup=args.warmup, t_total=train_steps)


for counter, model in enumerate(models_3_obs):
    total_loss = 0.
    result = 0.0
    best_result = 0.0
    label_name = label_names[counter]
    label_ids = labels_ids[counter]
    for epoch in range(1, args.epochs_num+1):

        model.train()


        for i, (input_ids_batch, label_ids_batch, mask_ids_batch, pos_ids_batch, vms_batch) in enumerate(batch_loader(batch_size, input_ids, label_ids, mask_ids, pos_ids, vms)):
            model.zero_grad()

            vms_batch = torch.LongTensor(np.array(vms_batch))

            input_ids_batch = input_ids_batch.to(device)
            label_ids_batch = label_ids_batch.to(device)
            mask_ids_batch = mask_ids_batch.to(device)
            pos_ids_batch = pos_ids_batch.to(device)
            vms_batch = vms_batch.to(device)

            loss, _ = model(input_ids_batch, label_ids_batch, mask_ids_batch, pos=pos_ids_batch, vm=vms_batch)
            if torch.cuda.device_count() > 1:
                loss = torch.mean(loss)
            total_loss += loss.item()
            if (i + 1) % args.report_steps == 0:
                print("labelsname: {}, Epoch id: {}, Training steps: {}, Avg loss: {:.3f}".format(label_name,epoch, i+1, (total_loss / args.report_steps)))
                sys.stdout.flush()
                total_loss = 0.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print("Start evaluation on test dataset.")         
        result = evaluate(args, False,label_id=counter)
        if result > best_result:
            best_result = result
            save_model(model, args.output_model_path[counter])
        else:
            continue

        print("Start evaluation on evaluation dataset.")
        evaluate(args, True,label_id=counter)

    # Evaluation phase.
    print("Final evaluation on the evaluation dataset.")

    if torch.cuda.device_count() > 1:
        model.module.load_state_dict(torch.load(args.output_model_path[counter]))
    else:
        model.load_state_dict(torch.load(args.output_model_path[counter]))
    evaluate(args, True,label_id=counter)
'''

'\nmodels_rest = models[3:]\n\ntrain_steps = int(instances_num * args.epochs_num / batch_size) + 1\n\nprint("Batch size: ", batch_size)\nprint("The number of training instances:", instances_num)\n\nparam_optimizer = list(model.named_parameters())\nno_decay = [\'bias\', \'gamma\', \'beta\']\noptimizer_grouped_parameters = [\n            {\'params\': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], \'weight_decay_rate\': 0.01},\n            {\'params\': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], \'weight_decay_rate\': 0.0}\n]\noptimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate, warmup=args.warmup, t_total=train_steps)\n\n\nfor counter, model in enumerate(models_3_obs):\n    total_loss = 0.\n    result = 0.0\n    best_result = 0.0\n    label_name = label_names[counter]\n    label_ids = labels_ids[counter]\n    for epoch in range(1, args.epochs_num+1):\n\n        model.train()\n\n\n        for i, (input_ids_batch, label_

# wrong prediction analysis

In [11]:
def report_wrong_predic(args, is_test, metrics='Acc',label_id = -1):
    counter = label_id
    if is_test:
        dataset = read_dataset(args.test_path, workers_num=args.workers_num)
    else:
        dataset = read_dataset(args.dev_path, workers_num=args.workers_num)
    labels_ids=[]
    input_ids = torch.LongTensor([sample[0] for sample in dataset])
    if label_id == -1:
        label_ids = torch.LongTensor([sample[1] for sample in dataset])
    for nr, label in enumerate(label_columns):
        labels_ids.append(torch.LongTensor([int(example[1][nr]) for example in dataset]))
    mask_ids = torch.LongTensor([sample[2] for sample in dataset])
    pos_ids = torch.LongTensor([example[3] for example in dataset])
    vms = [example[4] for example in dataset]
    label_ids=labels_ids[counter]
    batch_size = args.batch_size
    instances_num = input_ids.size()[0]
    if is_test:
        print("The number of evaluation instances: ", instances_num)


    model = models[counter]
    label_name =label_names[counter]
    label_ids = labels_ids[counter]

    model.eval()


    for i, (input_ids_batch, label_ids_batch,  mask_ids_batch, pos_ids_batch, vms_batch) in enumerate(batch_loader(batch_size, input_ids, label_ids, mask_ids, pos_ids, vms)):

          # vms_batch = vms_batch.long()
          vms_batch = torch.LongTensor(np.array(vms_batch))

          input_ids_batch = input_ids_batch.to(device)
          label_ids_batch = label_ids_batch.to(device)
          mask_ids_batch = mask_ids_batch.to(device)
          pos_ids_batch = pos_ids_batch.to(device)
          vms_batch = vms_batch.to(device)

          with torch.inference_mode():
              try:
                  loss, logits = model(input_ids_batch, label_ids_batch, mask_ids_batch, pos_ids_batch, vms_batch)
              except:
                  print(input_ids_batch)
                  print(input_ids_batch.size())
                  print(vms_batch)
                  print(vms_batch.size())

          logits = nn.Softmax(dim=1)(logits)
          pred = torch.argmax(logits, dim=1).to(device)
          gold = label_ids_batch

          for j in range(pred.size()[0]):
                  if pred[j] != gold[j]:
                    print(' '.join([x for x in args.tokenizer.convert_ids_to_tokens(input_ids_batch[j].tolist()) if x not in ['[PAD]','[CLS]']]))
                    print(f'pred: {pred.tolist()[j]}. correct value: {gold.tolist()[j]}')

In [12]:
label_columns

{'label_Atelectasis': 1,
 'label_Cardiomegaly': 2,
 'label_Consolidation': 3,
 'label_Edema': 4,
 'label_Enlarged Cardiomediastinum': 5,
 'label_Fracture': 6,
 'label_Lung Lesion': 7,
 'label_Lung Opacity': 8,
 'label_No Finding': 9,
 'label_Pleural Effusion': 10,
 'label_Pleural Other': 11,
 'label_Pneumonia': 12,
 'label_Pneumothorax': 13}

choose observation you wanted to check from above list and assign it into ```tested_label_name``` below

In [15]:
tested_label_name = 'label_Atelectasis'

In [16]:
counter = label_columns[tested_label_name]
model.load_state_dict(torch.load(args.output_model_path[counter]))
report_wrong_predic(args, True,label_id=counter)

Loading sentences from ./datasets/CheXpert/impression/test.csv
There are 33124 sentence in total. We use 1 processes to inject knowledge into sentences.
Progress of process 0: 0/33124
Progress of process 0: 10000/33124
Progress of process 0: 20000/33124
Progress of process 0: 30000/33124
The number of evaluation instances:  33124
com ##pa ##ris ##on to _ _ _ . the pa ##ti ##ent has re ##ce ##ive ##d a new right ex ##ter ##na ##l pa ##ce ##ma ##ker . the course of the de ##vice is u ##n ##re ##mark ##able , the ti ##p project ##s over the right ve ##nt ##ric ##le . no com ##pl ##ica ##tions , not ##ab ##ly no p ##ne ##um ##ot ##ho ##ra ##x .
pred: [3, 1, 3, 0, 0, 3, 1, 2, 2, 0, 1, 0, 2, 3, 0, 1]. correct value: [0, 0, 3, 0, 1, 0, 0, 2, 0, 1, 0, 0, 0, 0, 2, 0]
cl ##ini ##cal set ##ting . dr . _ _ _ _ _ _ dr . _ _ _ at 7 : 21 ##pm on _ _ _ .
pred: [3, 1, 3, 0, 0, 3, 1, 2, 2, 0, 1, 0, 2, 3, 0, 1]. correct value: [0, 0, 3, 0, 1, 0, 0, 2, 0, 1, 0, 0, 0, 0, 2, 0]
ch ##ron ##ic pu ##lm ##ona #

KeyboardInterrupt: 