In [5]:
import torch
import os
import sys
sys.path.append('../')

from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm
from transformers import AutoModelForMaskedLM, AutoTokenizer
from models.config.defaults import cfg
from models.LM import LM
from models.data import DATA

def _tokenize_texts(texts, tokenizer, max_length):
    # Tokenize each text padded/truncated to max_length and stack the ids into one (N, max_length) tensor.
    rows = []
    for text in tqdm(texts):
        encoded = tokenizer(
            text,
            add_special_tokens = True,
            padding = 'max_length',
            truncation = True,
            return_offsets_mapping = False,
            max_length = max_length,
            return_token_type_ids = False,
            return_attention_mask = False,
            return_tensors = 'pt',
        )
        rows.append(encoded['input_ids'].reshape(1, -1))
    return torch.cat(rows, dim=0)


def _index_entities(id_1s, id_2s):
    # Map every distinct entity id to a dense integer in first-seen order (deterministic across runs,
    # unlike iterating a set of strings, whose order depends on hash randomization).
    int2id = list(dict.fromkeys(list(id_1s) + list(id_2s)))
    id2int = {ent: i for i, ent in enumerate(int2id)}
    return int2id, id2int


def _write_results(dict_ans, result_path):
    # Write one "<id>,<match> <match> ..." line per entity.
    with open(result_path, 'w', encoding='utf-8') as f:
        for ent, matches in dict_ans.items():
            f.write(ent + ',' + ' '.join(matches) + '\n')


def inference(model, transformer, tokenizer, device, test_data, threshold = 0.5, result_path = './result.csv', batch_size = 32, max_length = None):
    """Score every (id_1, id_2) candidate pair in ``test_data`` and write the matches to ``result_path``.

    Each text is tokenized and fed to ``transformer``. The last hidden layer is mean-pooled
    over the token axis, and the pooled features go to ``model.predict``. Every row whose
    prediction equals 1 records ``id_2`` as a match of ``id_1``.

    Args:
        model: object exposing ``predict(features, threshold=...)``. Its result is compared
            with ``== 1`` and used as a boolean mask over the batch.
        transformer: model called as ``transformer(input_ids=...)``. Its output must contain
            ``'hidden_states'``, for example a Hugging Face model loaded with
            ``output_hidden_states=True``.
        tokenizer: tokenizer called once per text that returns ``'input_ids'`` as a tensor.
        device: torch device that each batch is moved to.
        test_data: mapping with parallel sequences ``'text'``, ``'id_1'`` and ``'id_2'``.
            The ids are written with string concatenation, so they are expected to be ``str``.
        threshold: decision threshold forwarded to ``model.predict``.
        result_path: output file. It gets one line per distinct entity, in first-seen order,
            formatted as ``<id>,<id> <match> <match> ...``. Each entity is listed first
            among its own matches.
        batch_size: number of rows per forward pass.
        max_length: token length that every text is padded/truncated to. Defaults to
            ``cfg.DATA.PREPROCESS_MAX_LEN`` when None.

    Returns:
        None. Results are only written to ``result_path``.
    """
    if max_length is None:
        max_length = cfg.DATA.PREPROCESS_MAX_LEN

    texts, id_1s, id_2s = test_data['text'], test_data['id_1'], test_data['id_2']

    int2id, id2int = _index_entities(id_1s, id_2s)
    # every entity starts out matched with itself
    dict_ans = {ent: [ent] for ent in int2id}

    if len(texts) == 0:
        # Nothing to score; torch.cat and DataLoader(batch_size=0) would fail on empty input.
        _write_results(dict_ans, result_path)
        return

    input_ids = _tokenize_texts(texts, tokenizer, max_length)
    id_1_idx = torch.tensor([id2int[ent] for ent in id_1s]).reshape(-1, 1)
    id_2_idx = torch.tensor([id2int[ent] for ent in id_2s]).reshape(-1, 1)

    # TensorDataset also checks that texts and both id lists have the same length.
    loader = DataLoader(TensorDataset(input_ids, id_1_idx, id_2_idx), batch_size = batch_size, shuffle = False)

    with torch.no_grad(), tqdm(total=len(input_ids)) as pbar:
        for batch_ids, batch_id_1, batch_id_2 in loader:
            batch_ids = batch_ids.to(device)
            batch_id_1 = batch_id_1.to(device)
            batch_id_2 = batch_id_2.to(device)
            # NOTE(review): no attention_mask is passed, and the mean includes padded positions.
            # Kept as-is because the features presumably must match how `model` was trained; confirm.
            hidden = transformer(input_ids = batch_ids)['hidden_states'][-1]
            features = torch.mean(hidden, dim=1)
            features = features.reshape(features.shape[0], -1)
            predict_res = model.predict(features, threshold = threshold)
            keep = predict_res == 1
            # One bulk conversion to Python ints instead of indexing device tensors element by element.
            matched_1 = batch_id_1[keep].reshape(-1).tolist()
            matched_2 = batch_id_2[keep].reshape(-1).tolist()
            for a, b in zip(matched_1, matched_2):
                dict_ans[int2id[a]].append(int2id[b])
            # the final batch may be smaller than batch_size
            pbar.update(len(batch_ids))

    _write_results(dict_ans, result_path)


In [6]:
# Inference configuration. DATA_SAVED = False presumably makes DATA regenerate the test
# data rather than load a saved copy (the captured output shows "generating test data..."); confirm.
cfg.MODEL.IS_TRAIN = False
cfg.DATA.DATA_SAVED = False

# Use the configured device, falling back to CPU when a non-CPU device is configured but CUDA is unavailable.
if cfg.MODEL.DEVICE == "cpu" or torch.cuda.is_available():
    device = torch.device(cfg.MODEL.DEVICE)
else:
    device = torch.device('cpu')

# Load the trained model and move it to the device.
model = LM(cfg)
model.load_model(cfg.TEST.MODEL_PATH)
model.to(device)

# Frozen pretrained encoder used only to produce features. `inference` consumes only
# 'hidden_states', so attention maps are not requested; with output_attentions=True they
# would be materialized for every layer on every batch and then discarded.
tokenizer = AutoTokenizer.from_pretrained(cfg.MODEL.PRETRAINED_MODEL_PATH)
transformer = AutoModelForMaskedLM.from_pretrained(
    cfg.MODEL.PRETRAINED_MODEL_PATH,
    output_hidden_states=True,
).to(device)
transformer.eval()
transformer.requires_grad_(False)

# Build the test data dict consumed by `inference` (keys 'text', 'id_1', 'id_2').
D = DATA(cfg)
test_data = D.get_test_data_dict(auto_gen=True, features = cfg.DATA.FEATURES[:-1], rounds = cfg.TEST.ROUNDS, n_neighbors=cfg.TEST.N_NEIGHBORS)

# Score every pair and write the matches to cfg.TEST.RESULT_PATH.
inference(model, transformer, tokenizer, device, test_data, threshold = cfg.TEST.BEST_THRESHOLD, result_path = cfg.TEST.RESULT_PATH, batch_size = cfg.TEST.BATCH_SIZE)

Loading data..., data path:  ../dataset
test_data:  5
generating test data...


100%|██████████| 5/5 [00:00<00:00, 16817.58it/s]
100%|██████████| 5/5 [00:00<00:00, 2868.49it/s]
100%|██████████| 5/5 [00:00<00:00, 2736.37it/s]
100%|██████████| 5/5 [00:00<00:00, 3290.17it/s]
100%|██████████| 15/15 [00:00<00:00, 3773.21it/s]


test_data_list:  15
organizing test_data_list as dictionary format: {'text': text, 'num_entities': num_entities}


100%|██████████| 15/15 [00:00<00:00, 22176.44it/s]
100%|██████████| 15/15 [00:00<00:00, 915.04it/s]


15


100%|██████████| 15/15 [00:08<00:00,  1.87it/s]
