In [8]:
import torch
import json
from tqdm import tqdm 
from torch.utils.data import TensorDataset
import faiss
import numpy as np 
from transformers import BertTokenizer, BertModel
import pickle
from process_data import load_data_split, get_candidate_representation

In [1]:
def load_dictionary(dictionary_path): 
    ids = []
    names = []
    with open(dictionary_path,'r') as f: 
        for line in f: 
            line = line.strip().split('||') 
            ids.append(line[0])
            names.append(line[1])
    return ids,names

In [3]:
biosyn_path = "/share/project/biomed/hcd/BioSyn/pretrained/biosyn-sapbert-bc5cdr-chemical"
train_data = load_data_split("./data/bc5cdr-c_v1/processed/train.jsonl")
val_data = load_data_split("./data/bc5cdr-c_v1/processed/val.jsonl")
test_data = load_data_split("./data/bc5cdr-c_v1/processed/test.jsonl")
#entities = load_entities("./data/bc5cdr-c_v1/entity_documents.json")
path = "/share/project/biomed/hcd/BioSyn/datasets/bc5cdr-chemical/train_dictionary.txt"
train_dict_ids, train_dict_names = load_dictionary(path)
# train_dict_ids = train_dict_ids[:100]
# train_dict_names = train_dict_names[:100]
path = "/share/project/biomed/hcd/BioSyn/datasets/bc5cdr-chemical/dev_dictionary.txt"
val_dict_ids, val_dict_names = load_dictionary(path)
# val_dict_ids = val_dict_ids[:100]
# val_dict_names = val_dict_names[:100]
path = "/share/project/biomed/hcd/BioSyn/datasets/bc5cdr-chemical/test_dictionary.txt"
test_dict_ids, test_dict_names = load_dictionary(path)
# test_dict_ids = test_dict_ids[:100]
# test_dict_names = test_dict_names[:100]
biosyn = BertModel.from_pretrained(biosyn_path)

biosyn_tokenizer = BertTokenizer.from_pretrained(biosyn_path)

biobert_tokenizer = BertTokenizer.from_pretrained("/share/project/biomed/hcd/arboEL/models/biobert-base-cased-v1.1")
device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

biosyn = biosyn.to(device)

In [4]:

def process_entity_dictionary(ids, names, tokenizer, biobert_tokenizer, dictionary_processed = True):
    ## tokenizer has to be biosyn tokenizer
    idx_to_id = {}
    id_to_idx = {}
    entity_dictionary = []
    with torch.no_grad():
        for idx, (id_, name) in enumerate(zip(ids,names)):
            print(idx,end='\r')
            #id_to_idx[id_] = idx
            idx_to_id[idx] = id_
            if id_ not in id_to_idx:
                id_to_idx[id_] = []
            id_to_idx[id_].append(idx)
            if not dictionary_processed:
                #label_representation = get_candidate_representation(name.lower(), tokenizer)
                biobert_representation = get_candidate_representation(name.lower(),biobert_tokenizer)
                entity_dictionary.append({
                    # "tokens": label_representation["tokens"],
                    # "ids": label_representation["ids"],
                    "biobert_tokens": biobert_representation["tokens"],
                    "biobert_ids": biobert_representation["ids"]
                }) 
                
                
    return id_to_idx, idx_to_id, entity_dictionary
            

In [20]:
train_id2idx, train_idx2id, train_dictionary = process_entity_dictionary(train_dict_ids, train_dict_names, biosyn_tokenizer, biobert_tokenizer)
# with open("./data/bc5cdr-c_v1/processed/train/idx2id.pkl",'wb') as handle:
#     pickle.dump(train_idx2id, handle)
    
# with open("./data/bc5cdr-c_v1/processed/train/id2idx.pkl",'wb') as handle:
#     pickle.dump(train_id2idx, handle)

with open("./data/bc5cdr-c_v1/processed/train/biobert_dict.pkl", 'wb') as write_handle:
                    pickle.dump(train_dictionary, write_handle,
                                protocol=pickle.HIGHEST_PROTOCOL)
val_id2idx, val_idx2id, val_dictionary = process_entity_dictionary(val_dict_ids, val_dict_names, biosyn_tokenizer, biobert_tokenizer)
# with open("./data/bc5cdr-c_v1/processed/val/idx2id.pkl",'wb') as handle:
#     pickle.dump(val_idx2id, handle)
# with open("./data/bc5cdr-c_v1/processed/val/id2idx.pkl",'wb') as handle:
#     pickle.dump(val_id2idx, handle)

with open("./data/bc5cdr-c_v1/processed/val/biobert_dict.pkl", 'wb') as write_handle:
                    pickle.dump(val_dictionary, write_handle,
                                protocol=pickle.HIGHEST_PROTOCOL)
                    
test_id2idx, test_idx2id, test_dictionary = process_entity_dictionary(test_dict_ids, test_dict_names, biosyn_tokenizer, biobert_tokenizer)
# with open("./data/bc5cdr-c_v1/processed/test/idx2id.pkl",'wb') as handle:
#     pickle.dump(test_idx2id, handle)
# with open("./data/bc5cdr-c_v1/processed/test/id2idx.pkl",'wb') as handle:
#     pickle.dump(test_id2idx, handle)
with open("./data/bc5cdr-c_v1/processed/test/biobert_dict.pkl", 'wb') as write_handle:
                    pickle.dump(test_dictionary, write_handle,
                                protocol=pickle.HIGHEST_PROTOCOL)

407599

In [11]:
print(len(test_idx2id),len(val_idx2id), len(train_idx2id))

407600 407454 407247


In [104]:
def embed_dictionary(model, device,  entity_dictionary):
    ent_embs = [] 
    with torch.no_grad():
        for idx, ent in enumerate(tqdm(entity_dictionary[:100], desc = "calculating embeddings")):
            input = torch.tensor(ent['ids'])
            #print(input.shape)
            emb = model(input[None,:].to(device))[0].mean(1).squeeze(0)
            #print(emb.shape)
            ent_embs.append(emb.cpu())
    ent_embs = torch.stack(ent_embs)
    
    return ent_embs

train_ent_embs = embed_dictionary(biosyn, device, train_dictionary)
torch.save(train_ent_embs, './data/bc5cdr-c_v1/processed/train/dictionary_embs.pt')
val_ent_embs = embed_dictionary(biosyn, device, val_dictionary)
torch.save(val_ent_embs, './data/bc5cdr-c_v1/processed/val/dictionary_embs.pt')

test_ent_embs = embed_dictionary(biosyn, device, test_dictionary)
torch.save(test_ent_embs, './data/bc5cdr-c_v1/processed/test/dictionary_embs.pt')


calculating embeddings: 100%|██████████| 100/100 [00:00<00:00, 144.90it/s]
calculating embeddings: 100%|██████████| 100/100 [00:00<00:00, 145.40it/s]
calculating embeddings: 100%|██████████| 100/100 [00:00<00:00, 146.44it/s]


In [19]:
def get_candidates(mention_samples, model, device, entity_embs, id2idx, idx2id, n_candidates, tokenizer, split, debug=False):
    ### id2idx is the one we got from process_entity_dictionary. e.g. D1234 -> idx
    ### here entity_dictionary is entity_documents.json
    ### here mention_samples is e.g. train.jsonl
    id2emb = {}
    d = 768
    res = faiss.StandardGpuResources()
    index_flat = faiss.IndexFlatL2(d)
    gpu_index_flat = faiss.index_cpu_to_gpu(res, 0, index_flat)
    gpu_index_flat.add(entity_embs)
    count = 0
    targets = []
    neighbor_list = []
    #embs = torch.load('')
    for idx, sample in enumerate(tqdm(mention_samples, desc = "calculating embeddings")):
            with torch.no_grad():
                mention_representation = get_candidate_representation(sample['mention'].lower(), tokenizer)
                mention_ids = mention_representation['ids']
                input = torch.tensor(mention_ids)
                emb = model(input[None,:].to(device))[0].mean(1)
            #emb = entity_embs[idx]
            #print(sample.keys())
            mention_label = sample['label_id']
            
                
            #print(mention_label_idx)
            #print(sample)
            emb = emb.cpu()
            #print(emb.shape)
            #idx2id = {v: k for k, v in id2idx.items()}
            id2emb[sample['mention_id']] = emb
            #top_k = n_candidates
            
            D, I = gpu_index_flat.search(emb, n_candidates)
            result = I[0]
            neighbor_ids = [idx2id[k] for k in result]
            if split == "train":
                if mention_label not in neighbor_ids:
                    if mention_label in id2idx:
                        mention_label_idx = id2idx[mention_label][0]
                        np.concatenate([[mention_label_idx], result])
                    
            neighbor_ids = [idx2id[k] for k in result]
            if mention_label in neighbor_ids:
                count +=1
            target = []
            
            for id_ in neighbor_ids:
                if id_ == mention_label:
                    target.append(1)
                else:
                    target.append(0)
            
            targets.append(target[0:n_candidates])
            neighbor_list.append(result[0:n_candidates])
       
    print(count/len(mention_samples))
    return neighbor_list, targets
n_candidates = 64
train_ent_embs = torch.load("/share/project/biomed/hcd/Masked_EL/data/bc5cdr-c_v1/processed/train/dictionary_embs.pt")
with open("/share/project/biomed/hcd/Masked_EL/data/bc5cdr-c_v1/processed/train/idx2id.pkl",'rb') as f:
    train_idx2id = pickle.load(f)
with open("/share/project/biomed/hcd/Masked_EL/data/bc5cdr-c_v1/processed/train/id2idx.pkl",'rb') as f:
    train_id2idx = pickle.load(f)
# with open("/share/project/biomed/hcd/Masked_EL/data/bc5cdr-c_v1/processed/train_dict.pkl",'rb') as f:
#     train_dict= pickle.load(f)
train_neighbors, train_labels = get_candidates(train_data, biosyn, device, train_ent_embs, train_id2idx, train_idx2id, n_candidates, biosyn_tokenizer, 'train')
np.save('./data/bc5cdr-c_v1/processed/train/neighbors.npy', train_neighbors)
np.save('./data/bc5cdr-c_v1/processed/train/neighbor_labels.npy', train_labels)

val_ent_embs = torch.load("/share/project/biomed/hcd/Masked_EL/data/bc5cdr-c_v1/processed/val/dictionary_embs.pt")
with open("/share/project/biomed/hcd/Masked_EL/data/bc5cdr-c_v1/processed/val/idx2id.pkl",'rb') as f:
    val_idx2id = pickle.load(f)
with open("/share/project/biomed/hcd/Masked_EL/data/bc5cdr-c_v1/processed/val/id2idx.pkl",'rb') as f:
    val_id2idx = pickle.load(f)
# with open("/share/project/biomed/hcd/Masked_EL/data/bc5cdr-c_v1/processed/val/dict.pkl",'rb') as f:
#     val_dict= pickle.load(f)
val_neighbors, val_labels = get_candidates(val_data, biosyn, device, val_ent_embs, val_id2idx, val_idx2id, n_candidates, biosyn_tokenizer, 'val')
np.save('./data/bc5cdr-c_v1/processed/val/neighbors.npy', val_neighbors)
np.save('./data/bc5cdr-c_v1/processed/val/neighbor_labels.npy', val_labels)

test_ent_embs = torch.load("/share/project/biomed/hcd/Masked_EL/data/bc5cdr-c_v1/processed/test/dictionary_embs.pt")
with open("/share/project/biomed/hcd/Masked_EL/data/bc5cdr-c_v1/processed/test/idx2id.pkl",'rb') as f:
    test_idx2id = pickle.load(f)
with open("/share/project/biomed/hcd/Masked_EL/data/bc5cdr-c_v1/processed/test/id2idx.pkl",'rb') as f:
    test_id2idx = pickle.load(f)
# with open("/share/project/biomed/hcd/Masked_EL/data/bc5cdr-c_v1/processed/test/dict.pkl",'rb') as f:
#     train_dict= pickle.load(f)
test_neighbors, test_labels = get_candidates(test_data, biosyn, device, test_ent_embs, test_id2idx, test_idx2id, n_candidates, biosyn_tokenizer, 'test')
np.save('./data/bc5cdr-c_v1/processed/test/neighbors.npy', test_neighbors)
np.save('./data/bc5cdr-c_v1/processed/test/neighbor_labels.npy', test_labels)


calculating embeddings: 100%|██████████| 5157/5157 [00:52<00:00, 97.85it/s] 


0.8413806476633702


calculating embeddings: 100%|██████████| 5302/5302 [00:53<00:00, 99.56it/s] 


0.8647680120709166


calculating embeddings: 100%|██████████| 5351/5351 [00:53<00:00, 99.76it/s] 


0.8346103532050084


In [8]:
test_neighbors[0]

array([209117, 209115, 297590, 209109, 209120, 385224, 384097, 209098,
       222391, 396693, 269771, 209100, 170623, 169043, 240858,  85215,
       270275, 131578, 336778, 121046, 276018, 367136, 338993, 337386,
       237454, 367056, 235127, 223937, 364547, 169042, 268553, 192206,
       390510, 336806, 255057, 189925, 263709, 371793, 231227, 398031,
       249123, 223609, 396111, 188149, 246124, 384096, 207327, 209136,
       328969, 202296,   5139, 137546, 113842, 269797, 128230, 382096,
       160194, 125115, 340852, 222794, 237453, 226018, 367775, 308691])

In [4]:
import os
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm, trange
#import numpy as np
from biencoder import BiEncoderRanker
#from process_data import load_data_split
import pickle
import argparse
import torch.nn.functional as F
import logging
from transformers import BertModel, BertTokenizer
from torch.utils.tensorboard import SummaryWriter

In [6]:
def setup_logger(name, log_file, level=logging.INFO):
    """To setup as many loggers as you want"""
    formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')

    handler = logging.FileHandler(log_file, mode='a')        
    handler.setFormatter(formatter)

    logger = logging.getLogger(name)
    logger.setLevel(level)
    logger.addHandler(handler)

    consoleHandler = logging.StreamHandler()
    consoleHandler.setFormatter(formatter)

    logger.addHandler(consoleHandler)
    return logger

def calculate_accuracy(max_idxs, correct):
    #acc = 0
    for idx in max_idxs:
        if idx == 0:
           correct+=1
    return correct

def evaluate(reranker, val_dataloader, criterion, entities, neighbors, labels, device, logger):
    reranker.model.eval()
    total_samples = 0
    correct = 0
    loss_list = []
    with torch.no_grad():
        for step, batch in enumerate(tqdm(val_dataloader,desc="validation mini batches")):
                torch.cuda.empty_cache()
            
                batch = tuple(t.to(device) for t in batch)
                mention_idxs, context_ids = batch    
                candidate_ids = []
                
                for idx in mention_idxs:
                    candidate_ids.append([])
                    for neighbor in neighbors[idx]:
                        candidate_ids[-1].append(entities[neighbor]['biobert_ids'])
                
                candidate_ids = torch.tensor(candidate_ids).to(device)
                #print(candidate_ids.shape)
                scores = reranker(context_ids, candidate_ids, device) 
                total_samples += scores.shape[0]
                
                labels = torch.FloatTensor(labels,device)
                max_idxs = scores.argmax(dim=1)
                correct = calculate_accuracy(max_idxs, correct)
                probs = F.softmax(scores, dim=1)

                #p_correct = probs[:, 0]
                
                loss = criterion(probs, labels)
                
                loss_list.append(loss.item())

    logger.info("Evaluation completed, with total samples: {}".format(total_samples))           
    return sum(loss_list)/len(loss_list), correct/total_samples


In [10]:
def load_pickle(path):
    with open(path,'rb') as handle:
        pkl = pickle.load(handle)
    return pkl

In [4]:
import os
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm, trange
import numpy as np
from biencoder import BiEncoderRanker
import pickle
import argparse
import torch.nn.functional as F
from process_data import get_context_representation
import logging
from transformers import BertModel, BertTokenizer
from torch.utils.tensorboard import SummaryWriter

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def select_field(data, key1, key2=None):
    if key2 is None:
        return [example[key1] for example in data]
    else:
        return [example[key1][key2] for example in data]
def process_mention_data(samples, tokenizer, id_to_idx, debug = False, max_length=512):
    processed_samples = []
    print("len of id2idx", len(id_to_idx))
    if debug:
        print("reducing sample size")
        samples = samples[:300]
    not_added = 0
    #mention_idxs = []
    for idx, sample in enumerate(tqdm(samples, desc = "Tokenizing mentions")):
        context = get_context_representation(sample, tokenizer)
        label_id = sample['label_id']
        
        try:
            record = {
                "mention": sample['mention'],
                "context_tokens": context['tokens'],
                "context_ids": context['ids'],
                "mention_idx": idx,
                "mention_id":sample['mention_id'],
                #"label_title": sample['label_title'],
                "label_title": sample['label'],
                "label_id":sample['label_id'],
                "label_idxs": id_to_idx[sample['label_id']]
                
            }
            processed_samples.append(record)
        except:
            not_added +=1
        
        
        #print(processed_samples)
    print("not added, due to inconsistency:",not_added)
    context_tensors = torch.tensor(
        select_field(processed_samples, "context_ids"), dtype=torch.long
    )
    print(context_tensors.shape)
    #print("processed_samples label idxs",processed_samples['label_idxs'].shape)
    mentiond_idxs = torch.tensor(
        select_field(processed_samples, "mention_idx"), dtype=torch.long
    )
    
    # label_idxs = torch.tensor(
    # select_field(processed_samples, "label_idxs"), dtype=torch.long,
    # )

    tensor_data = TensorDataset(mentiond_idxs, context_tensors)

    #return processed_samples, tensor_data
    return tensor_data


In [12]:
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
def main(params):
    
    writer = SummaryWriter()
    logger = setup_logger('biencoder_logger', './logs/biencoder64.log')
    reranker = BiEncoderRanker(params)
    
    model = reranker.model
    optimizer = reranker.optimizer
    tokenizer = reranker.tokenizer
    device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
    #reranker = reranker.to(device)
    #model = model.to(device)
    n_gpu = reranker.n_gpu
    #biosyn_path = "/share/project/biomed/hcd/BioSyn/pretrained/biosyn-sapbert-bc5cdr-chemical"
    train_data = load_data_split("./data/bc5cdr-c_v1/processed/train.jsonl")
    val_data = load_data_split("./data/bc5cdr-c_v1/processed/val.jsonl")
    test_data = load_data_split("./data/bc5cdr-c_v1/processed/test.jsonl")
    #entities = load_entities("./data/bc5cdr-c_v1/entity_documents.json")
    # biosyn = BertModel.from_pretrained(biosyn_path).to(device)
    # biosyn_tokenizer = BertTokenizer.from_pretrained(biosyn_path)
    path = "/share/project/biomed/hcd/Masked_EL/data/bc5cdr-c_v1/processed/train/id2idx.pkl"
    train_id2idx = load_pickle(path)
    #print(train_id2idx['D002216'])
    path = "/share/project/biomed/hcd/Masked_EL/data/bc5cdr-c_v1/processed/train/biobert_dict.pkl"
    train_entities = load_pickle(path)

    # path = "/share/project/biomed/hcd/Masked_EL/data/bc5cdr-c_v1/processed/train/dictionary_embs.pt"
    # # train_ent_embs = torch.load(path)
    path = "/share/project/biomed/hcd/Masked_EL/data/bc5cdr-c_v1/processed/val/id2idx.pkl"
    
    val_id2idx = load_pickle(path)
    path = "/share/project/biomed/hcd/Masked_EL/data/bc5cdr-c_v1/processed/val/biobert_dict.pkl"
    val_entities = load_pickle(path)
    # path = "/share/project/biomed/hcd/Masked_EL/data/bc5cdr-c_v1/processed/val/dictionary_embs.pt"
    # val_ent_embs = torch.load(path)
    
    train_tensor_data = process_mention_data(train_data, tokenizer, train_id2idx, params["debug"])
    val_tensor_data = process_mention_data(val_data, tokenizer, val_id2idx, params["debug"])
    batch_size = params['batch_size']
    val_test_batch_size = params["val_test_batch_size"]
    train_dataloader = DataLoader(train_tensor_data, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_tensor_data, batch_size = val_test_batch_size, shuffle = True)

    num_train_epochs = params['epoch']

    n_candidates = params['n_candidates']
    
    path = "/share/project/biomed/hcd/Masked_EL/data/bc5cdr-c_v1/processed/train/neighbors.npy"

    train_neighbors = np.load(path)
    path = "/share/project/biomed/hcd/Masked_EL/data/bc5cdr-c_v1/processed/train/neighbor_labels.npy"
    train_neighbor_labels = np.load(path)
    
    
    path = "/share/project/biomed/hcd/Masked_EL/data/bc5cdr-c_v1/processed/val/neighbors.npy"
    val_neighbors = np.load(path)
    path = "/share/project/biomed/hcd/Masked_EL/data/bc5cdr-c_v1/processed/val/neighbor_labels.npy"
    val_neighbor_labels = np.load(path)
    #reranker = reranker.to(device)
    best_val = 0
    best_model = None
    output = params['output_path']
    #n_candidates = params['n_candidates']
    #reranker.to(device)
    criterion = torch.nn.BCELoss()

    for epoch_idx in trange(int(num_train_epochs), desc="Epoch"):
        model.train()
        logger.info("Training")
        #torch.cuda.empty_cache()
        gradient_accumulation_steps = 16
        loss_list = []
        correct = 0
        total_samples = 0
        
        for step, batch in enumerate(tqdm(train_dataloader, desc = "Processing minibatches")):
            #optimizer.zero_grad()
            batch = tuple(t.to(device) for t in batch)
            mention_idxs, context_ids = batch    
            candidate_ids = []
            print(mention_idxs.shape, context_ids.shape)
            # for idx in label_idxs:
            #     candidate_ids.append([])
      
            #     for neighbor in train_neighbors[idx]:
   
            #         candidate_ids[-1].append(train_entities[neighbor]['ids'])
            for idx in mention_idxs:
                candidate_ids.append([])
                for neighbor in train_neighbors[idx]:
                    candidate_ids[-1].append(train_entities[neighbor]['biobert_ids'])
                    
            
            #context_ids = context_ids.to(device)
            candidate_ids = torch.tensor(candidate_ids).to(device)
            print("candidate_ids shape",candidate_ids.shape)
            print("context ids shape ", context_ids.shape)
            scores = reranker(context_ids, candidate_ids, device) 
            total_samples += scores.shape[0]
            ## for binary cross entropy loss
            train_labels = torch.FloatTensor(train_neighbor_labels, device)
           
                
            # target = torch.zeros_like(scores)
            # target[:,0] = 1
            # target = target.float().to(device)
            
            # print(train_labels.shape)
            # probs = F.softmax(scores, dim=1)
            # loss =criterion(probs, train_labels)
            # max_idxs = scores.argmax(dim=1)
            # correct = calculate_accuracy(max_idxs, correct)
           
            
            if params['data_parallel']:
                loss.mean().backward()
            else:
                loss.backward()
                
            if (step+1) % gradient_accumulation_steps == 0 or (step+1) == len(train_dataloader):
                loss_list.append(loss.item())
                logger.info("step: {}, accuracy: {}".format(step, correct/total_samples))
                loss = loss/gradient_accumulation_steps
                optimizer.step()
                optimizer.zero_grad()
                
                
        logger.info("Training completed, with total samples: {}".format(total_samples))    
        logger.info("Training, epoch: {}, loss_list: {}, epoch_loss: {}, accuracy: {}".format(epoch_idx,loss_list, sum(loss_list)/len(loss_list), correct/total_samples))
        writer.add_scalar("Trainining loss/epoch 64 candidates", sum(loss_list)/len(loss_list), epoch_idx)
        
        writer.add_scalar("Training accuracy/epoch 64 candidates", correct/total_samples, epoch_idx)
        logger.info("Evaluating") 
        validation_loss, validation_accuracy = evaluate(reranker,val_dataloader, criterion, val_entities,val_neighbors,val_neighbor_labels, device, logger)

        logger.info("Validation, epoch: {}, loss: {}, accuracy: {}".format(epoch_idx, validation_loss, validation_accuracy))
        
        writer.add_scalar("Validation loss/epoch 64 candidates", validation_loss, epoch_idx)
        writer.add_scalar("Validation accuracy/epoch 64 candidates", validation_accuracy, epoch_idx)
        
        # if validation_accuracy > best_val:
        #     best_val = validation_accuracy
        #     best_model = {'model': model.state_dict(),
        #       'optimizer': optimizer.state_dict()}
        #     if not os.path.exists('./model_ckpts/'+output):
        #         os.makedirs('./model_ckpts/'+output)
        #     torch.save(best_model, './model_ckpts/'+output+'/best_model.pt')
            
        # if (epoch_idx+1)%10 == 0:
        #     checkpoint = {'model': model.state_dict(),
        #       'optimizer': optimizer.state_dict()}
        #     torch.save(checkpoint, './model_ckpts/'+output+'/model_ckpt_'+str(epoch_idx)+'.pt')
            
    logger.info("Evaluating on test set")    
    
    path = "/share/project/biomed/hcd/Masked_EL/data/bc5cdr-c_v1/processed/test/dict.pkl"
    test_entities = load_pickle(path)
    test_tensor_data = process_mention_data(test_data, tokenizer, test_id2idx, params['evaluation'], params["debug"])
    test_dataloader = DataLoader(test_tensor_data, batch_size = val_test_batch_size, shuffle = False)
    test_loss, test_accuracy = evaluate(reranker,test_dataloader, criterion, test_entities, test_neighbors, device, logger)
    logger.info("Evaluation loss: {}, Accuracy: {}".format(test_loss, test_accuracy))
    # writer.add_scalar("Test loss", test_loss)
    # writer.add_scalar("Test accuracy", test_accuracy)
    
    writer.flush()
    writer.close()
    
     
if __name__ == "__main__":
    
    parameters = {"n_candidates":64, "val_test_batch_size":512, "gradient_accumulation_steps":32, "debug":False, "learning_rate":1e-6, "bert_model":"/share/project/biomed/hcd/arboEL/models/biobert-base-cased-v1.1", "out_dim":768, "epoch":20, "output_path":"biosyn", "batch_size":8,  "n_gpu":1, "contrastive":False, "pairwise":False, "data_parallel":False}
    main(parameters)

Some weights of the model checkpoint at /share/project/biomed/hcd/arboEL/models/biobert-base-cased-v1.1 were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


len of id2idx 171267


Tokenizing mentions: 100%|██████████| 5157/5157 [00:24<00:00, 212.42it/s]


not added, due to inconsistency: 268
torch.Size([4889, 512])
len of id2idx 171277


Tokenizing mentions: 100%|██████████| 5302/5302 [00:23<00:00, 221.13it/s]


not added, due to inconsistency: 275
torch.Size([5027, 512])


Epoch:   0%|          | 0/20 [00:00<?, ?it/s]2023-05-11 03:32:25,176 INFO Training
2023-05-11 03:32:25,176 INFO Training
2023-05-11 03:32:25,176 INFO Training
2023-05-11 03:32:25,176 INFO Training
Processing minibatches:   0%|          | 0/612 [00:00<?, ?it/s]
Epoch:   0%|          | 0/20 [00:00<?, ?it/s]


torch.Size([8]) torch.Size([8, 512])
candidate_ids shape torch.Size([8, 64, 25])
context ids shape  torch.Size([8, 512])


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper__index_select)

In [1]:
import pickle
with open('./data/bc5cdr-c/processed/dictionary/id2idx.pkl','rb') as f:
    my_dict = pickle.load(f)
print(my_dict['D002188'])

KeyError: 'D002188'