In [None]:
from modeling_biencoder import BiEncoder_Normer
# import os
# os.environ['CUDA_VISIBLE_DEVICES']="1"
from transformers import AutoTokenizer, AutoModel, AutoConfig
import transformers
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_metric_learning import miners, losses, distances
from typing import List, Optional, Tuple, Union
import pandas as pd
import numpy as np
from pytorch_metric_learning.distances import CosineSimilarity
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu
from torch.optim import AdamW
from transformers import Trainer
import datasets
from safetensors.torch import load_file
import joblib
from sklearn.preprocessing import normalize
import numpy as np
from tqdm import tqdm
import pandas as pd
import glob
from pathlib import Path
import json

In [None]:
## Settings required before running -- replace every placeholder value below.
context_encoder_path = 'path_to_context_encoder_initial_checkpoint'  # initial checkpoint of the context encoder
concept_encoder_path = 'path_to_concept_encoder_initial_checkpoint'  # initial checkpoint of the concept encoder
ADD_SOFT_TOKEN = True  # wrap the mention span in "[SOFT] ... [SOFT]" markers before tokenizing
MAX_LENGTH = "max_length_here"  # must be replaced with an int: the collator passes it to min() with token counts
SAFE_TENSOR_FILE = 'path_to_saved_parameters'  # safetensors file loaded into the bi-encoder
UNIENCODER = False  # True: tokenize the context alone; False: tokenize (mention, context) as a pair
# Forwarded to BiEncoder_Normer; enabled only in uni-encoder mode.
mean_pooling = True if UNIENCODER else False

In [None]:
# Tokenizer and model config are both taken from the context-encoder checkpoint.
context_tokenizer = AutoTokenizer.from_pretrained(context_encoder_path)
config = AutoConfig.from_pretrained(context_encoder_path)
biencoder = BiEncoder_Normer(
    concept_encoder=concept_encoder_path,
    context_encoder=context_encoder_path,
    mean_pooling=mean_pooling,
    projection='linear',
    config=config,
)

if ADD_SOFT_TOKEN:
    # Register the [SOFT] marker and grow the context encoder's embedding table to the
    # new vocabulary size. Presumably the saved weights were trained with this extra
    # token, so the shapes must match before loading -- keep this before load_state_dict.
    new_special_tokens = dict(additional_special_tokens=['[SOFT]'])
    context_tokenizer.add_special_tokens(new_special_tokens)
    biencoder.context_encoder.resize_token_embeddings(len(context_tokenizer))

# Overwrite the initial weights with the saved parameters (strict key matching by default).
state_dict = load_file(SAFE_TENSOR_FILE)
biencoder.load_state_dict(state_dict)

In [None]:
def preprocess_function_test(datapoint):
    """Tokenize one test example and locate its mention span in token space.

    Reads the notebook-level settings ``ADD_SOFT_TOKEN`` and ``UNIENCODER`` and the
    ``context_tokenizer``.

    Args:
        datapoint: a row mapping with ``mention`` and ``sentence`` (both converted with
            ``str``) and ``char_pos``, a ``(start, end)`` pair of character offsets of
            the mention inside ``sentence``. ``start`` may be ``None``. A ``char_pos``
            that is itself ``None`` is treated as ``(None, None)``.

    Returns:
        dict with
          * ``context_input_ids``: list of token ids,
          * ``mention_start`` / ``mention_end``: plain ``int`` token offsets
            (end-exclusive) of the mention,
          * ``concept_input_ids`` / ``labels``: always ``None`` (placeholders).
        When soft markers are not used, the span runs from index 1 to ``len - 1``.
    """
    model_inputs = {
        "concept_input_ids": None,
        "context_input_ids": None,
        "labels": None
    }
    mentions = str(datapoint['mention'])
    contexts = str(datapoint['sentence'])
    char_pos = datapoint['char_pos']
    start, end = char_pos if char_pos is not None else (None, None)
    use_soft_markers = ADD_SOFT_TOKEN and start is not None

    if use_soft_markers:
        # Surround the mention with [SOFT] markers so its token span can be
        # recovered after tokenization.
        contexts = contexts[:start] + '[SOFT] ' + contexts[start:end] + ' [SOFT]' + contexts[end:]

    if UNIENCODER:
        context_input_ids = context_tokenizer.encode(contexts, return_tensors='pt')[0]
    else:
        # Pair encoding: the mention is the first segment, the sentence the second.
        context_input_ids = context_tokenizer.encode(mentions, contexts, return_tensors='pt')[0]

    if use_soft_markers:
        special_id = context_tokenizer.additional_special_tokens_ids[0]
        indices = torch.where(context_input_ids == special_id)[0]
        # Convert to plain ints so both branches yield the same type (not 0-d tensors).
        mention_start = int(indices[0]) + 1  # first token after the opening marker
        mention_end = int(indices[1])        # closing marker, exclusive bound
    else:
        # No markers: cover everything except the first and last token.
        mention_start = 1
        mention_end = len(context_input_ids) - 1

    model_inputs["context_input_ids"] = context_input_ids.tolist()
    model_inputs["mention_start"] = mention_start
    model_inputs["mention_end"] = mention_end

    return model_inputs

def data_collator(batch, max_length: Optional[int] = None, pad_token_id: int = 0):
    """Pad or truncate preprocessed examples into batched context tensors.

    Args:
        batch: a sequence of examples (a list of dicts or a ``datasets.Dataset``).
            Each example has ``context_input_ids`` (list of token ids) and
            ``mention_start`` / ``mention_end`` token offsets (end-exclusive).
        max_length: hard cap on the sequence length. Defaults to the notebook-level
            ``MAX_LENGTH`` setting.
        pad_token_id: id written into padded positions (0 by default, as before).

    Returns:
        dict with ``context_input_ids``, ``context_attention_mask`` and
        ``context_mean_inputs_mask``. Each is a ``(len(batch), seq_len)`` long tensor,
        where ``seq_len = min(max_length, longest example)``.

    Raises:
        ValueError: if ``batch`` is empty.
    """
    if max_length is None:
        max_length = MAX_LENGTH
    rows = list(batch)  # materialize once; a Dataset would otherwise decode rows twice
    if not rows:
        raise ValueError("data_collator received an empty batch")
    seq_len = min(max_length, max(len(row["context_input_ids"]) for row in rows))

    batch_context_input_ids = []
    batch_context_attention_mask = []
    batch_context_mean_inputs_mask = []
    for row in rows:
        # Truncate first so the padding count can never be negative. A negative count
        # would turn `mask[-n:] = 0` into zeroing the real tokens instead of the padding.
        context_input_ids = list(row["context_input_ids"][:seq_len])
        n_tokens = len(context_input_ids)
        n_padding = seq_len - n_tokens
        batch_context_input_ids.append(
            torch.tensor(context_input_ids + [pad_token_id] * n_padding, dtype=torch.long)
        )

        context_attention_mask = torch.zeros(seq_len, dtype=torch.long)
        context_attention_mask[:n_tokens] = 1
        batch_context_attention_mask.append(context_attention_mask)

        # Marks the mention tokens [mention_start, mention_end). If truncation cut part
        # or all of the mention off, only the surviving positions (possibly none) are set.
        context_mean_inputs_mask = torch.zeros(seq_len, dtype=torch.long)
        context_mean_inputs_mask[row["mention_start"]:row["mention_end"]] = 1
        batch_context_mean_inputs_mask.append(context_mean_inputs_mask)

    return {
        "context_input_ids": torch.stack(batch_context_input_ids),
        "context_attention_mask": torch.stack(batch_context_attention_mask),
        "context_mean_inputs_mask": torch.stack(batch_context_mean_inputs_mask),
    }

In [None]:
import ast

# Load the evaluation rows. keep_default_na=False keeps empty cells as "" rather than NaN.
test_file = 'test_file.csv'
test = pd.read_csv(test_file, keep_default_na=False)
# `char_pos` is stored as a Python-literal string (presumably "(start, end)"); parse it safely.
test['char_pos'] = test['char_pos'].map(ast.literal_eval)

In [None]:
# Convert the DataFrame to a Hugging Face Dataset, then tokenize every row.
# NOTE: this rebinds `test`; re-run the CSV-loading cell before re-running this one.
test = datasets.Dataset.from_pandas(test)
test = test.map(preprocess_function_test)

In [None]:
# Collate the whole test set into a single padded batch of tensors.
model_inputs = data_collator(batch=test)

In [None]:
# Inference only: switch dropout and other train-time layers off, and skip autograd
# bookkeeping to save memory. The model is never put in eval mode elsewhere.
biencoder.eval()
with torch.no_grad():
    query_embeds = biencoder.query_embedding(model_inputs, device='cuda:0')

In [None]:
## Save the query embeddings with joblib.
# NOTE: the output is a pickle -- only load it back from a trusted location.
path = 'path_to_embeddings'
with open(path, mode='wb') as embeddings_file:
    joblib.dump(query_embeds, embeddings_file)