In [None]:
# !pip install numpy
# !pip install pandas
# !pip install transformers
# !pip install sentence-transformers
# !pip install cogdl

In [None]:
import os
import time
import random
import numpy as np
import pandas as pd
import gc

from cogdl.oag import oagbert
import torch
from torch import cuda
import torch.nn.functional as F
from sentence_transformers import util

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

def fix_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU and all CUDA devices), then switches cuDNN to deterministic
    mode so repeated runs with the same seed use the same kernels.

    Args:
        seed: Integer seed applied to every generator.
    """
    # random
    random.seed(seed)
    # Numpy
    np.random.seed(seed)
    # Pytorch
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may select different algorithms from run to run;
    # turn it off explicitly in case it was enabled earlier in this session.
    torch.backends.cudnn.benchmark = False
    
# Directory that the context-feature CSV files are read from.
path = 'path_to_context_feature'

# Seed all random number generators once, before any model work happens.
SEED = 2024
fix_seed(seed=SEED)

In [None]:
# Load the training table from the context-feature directory.
df_train_context = pd.read_csv(
    os.path.join(path, 'df_train_context_filled_keywords.csv')
)

In [None]:
# Preview the first five rows as a quick sanity check of the loaded table.
df_train_context.head(n=5)

In [None]:
def _embed_paper(model, device, title, abstract, venue, concepts):
    """Encode one paper's metadata with OAG-BERT and return its embedding.

    Defined inside this cell so the cell can be run on its own.

    Args:
        model: Loaded OAG-BERT model exposing ``build_inputs`` and ``bert``.
        device: Torch device string the model has been moved to.
        title: Paper title, passed to ``build_inputs`` unchanged.
        abstract: Paper abstract, passed to ``build_inputs`` unchanged.
        venue: Publication venue, passed to ``build_inputs`` unchanged.
        concepts: Keyword field, passed to ``build_inputs`` unchanged.

    Returns:
        The second value returned by ``model.bert.forward`` for a batch of
        one paper; the caller uses it as the paper embedding.
    """
    (
        input_ids,
        input_masks,
        token_type_ids,
        _masked_lm_labels,
        position_ids,
        position_ids_second,
        _masked_positions,
        _num_spans,
    ) = model.build_inputs(
        title=title, abstract=abstract, venue=venue, concepts=concepts
    )

    def _as_batch(ids):
        # Wrap the ids in a LongTensor with a leading batch dimension of 1.
        return torch.LongTensor(ids).unsqueeze(0).to(device)

    _, paper_embed = model.bert.forward(
        input_ids=_as_batch(input_ids),
        token_type_ids=_as_batch(token_type_ids),
        attention_mask=_as_batch(input_masks),
        output_all_encoded_layers=False,
        checkpoint_activations=False,
        position_ids=_as_batch(position_ids),
        position_ids_second=_as_batch(position_ids_second),
    )
    return paper_embed


emb_list = []
emb2_list = []
cossim_list = []

tokenizer, model = oagbert("oagbert-v2-sim")
model.to(device)
model.eval()

start_time = time.time()

# Inference only: without no_grad every forward pass records an autograd
# graph, which wastes memory and time for embeddings that are never trained.
with torch.no_grad():
    for i, row in df_train_context.iterrows():
        if i % 10 == 0:
            # Progress report (row index, elapsed seconds) plus a periodic
            # release of unreferenced Python objects and cached CUDA blocks.
            print(i, time.time() - start_time)
            gc.collect()
            torch.cuda.empty_cache()

        # NOTE(review): `concepts` receives the whole keywords cell as a single
        # string; confirm build_inputs does not expect a list of keyword strings.
        paper_embed_src = _embed_paper(
            model,
            device,
            title=str(row['title']),
            abstract=str(row['abstract']),
            venue=str(row['venue']),
            concepts=str(row['keywords']),
        )
        paper_embed_src2 = _embed_paper(
            model,
            device,
            title=str(row['ref_title']),
            abstract=str(row['ref_abstract']),
            venue=str(row['ref_venue']),
            concepts=str(row['ref_keywords']),
        )

        # cos_sim returns a matrix; for a single pair the score is at [0][0].
        temp_cos_sim = util.cos_sim(paper_embed_src, paper_embed_src2).detach().cpu().numpy()[0][0]
        emb_list.append(paper_embed_src.detach().cpu().numpy())
        emb2_list.append(paper_embed_src2.detach().cpu().numpy())
        cossim_list.append(temp_cos_sim)

# Uncomment to also persist the raw embeddings of both papers.
# np.save('train_context_emb1.npy', np.array(emb_list).reshape(-1,768))
# np.save('train_context_emb2.npy', np.array(emb2_list).reshape(-1,768))
np.save('train_context_cossim.npy', np.array(cossim_list).reshape(-1,1))

In [None]:
# Load the test table from the context-feature directory.
df_test_pub_gen_context = pd.read_csv(
    os.path.join(path, 'test_pub_gen_context_filled_citation.csv')
)

In [None]:
def _embed_paper(model, device, title, abstract, venue, concepts):
    """Encode one paper's metadata with OAG-BERT and return its embedding.

    Defined inside this cell so the cell can be run on its own.

    Args:
        model: Loaded OAG-BERT model exposing ``build_inputs`` and ``bert``.
        device: Torch device string the model has been moved to.
        title: Paper title, passed to ``build_inputs`` unchanged.
        abstract: Paper abstract, passed to ``build_inputs`` unchanged.
        venue: Publication venue, passed to ``build_inputs`` unchanged.
        concepts: Keyword field, passed to ``build_inputs`` unchanged.

    Returns:
        The second value returned by ``model.bert.forward`` for a batch of
        one paper; the caller uses it as the paper embedding.
    """
    (
        input_ids,
        input_masks,
        token_type_ids,
        _masked_lm_labels,
        position_ids,
        position_ids_second,
        _masked_positions,
        _num_spans,
    ) = model.build_inputs(
        title=title, abstract=abstract, venue=venue, concepts=concepts
    )

    def _as_batch(ids):
        # Wrap the ids in a LongTensor with a leading batch dimension of 1.
        return torch.LongTensor(ids).unsqueeze(0).to(device)

    _, paper_embed = model.bert.forward(
        input_ids=_as_batch(input_ids),
        token_type_ids=_as_batch(token_type_ids),
        attention_mask=_as_batch(input_masks),
        output_all_encoded_layers=False,
        checkpoint_activations=False,
        position_ids=_as_batch(position_ids),
        position_ids_second=_as_batch(position_ids_second),
    )
    return paper_embed


emb_list = []
emb2_list = []
cossim_list = []

tokenizer, model = oagbert("oagbert-v2-sim")
model.to(device)
model.eval()

start_time = time.time()

# Inference only: without no_grad every forward pass records an autograd
# graph, which wastes memory and time for embeddings that are never trained.
with torch.no_grad():
    for i, row in df_test_pub_gen_context.iterrows():
        if i % 10 == 0:
            # Progress report (row index, elapsed seconds) plus a periodic
            # release of unreferenced Python objects and cached CUDA blocks.
            print(i, time.time() - start_time)
            gc.collect()
            torch.cuda.empty_cache()

        # NOTE(review): `concepts` receives the whole keywords cell as a single
        # string; confirm build_inputs does not expect a list of keyword strings.
        paper_embed_src = _embed_paper(
            model,
            device,
            title=str(row['title']),
            abstract=str(row['abstract']),
            venue=str(row['venue']),
            concepts=str(row['keywords']),
        )
        paper_embed_src2 = _embed_paper(
            model,
            device,
            title=str(row['ref_title']),
            abstract=str(row['ref_abstract']),
            venue=str(row['ref_venue']),
            concepts=str(row['ref_keywords']),
        )

        # cos_sim returns a matrix; for a single pair the score is at [0][0].
        temp_cos_sim = util.cos_sim(paper_embed_src, paper_embed_src2).detach().cpu().numpy()[0][0]
        emb_list.append(paper_embed_src.detach().cpu().numpy())
        emb2_list.append(paper_embed_src2.detach().cpu().numpy())
        cossim_list.append(temp_cos_sim)

# Uncomment to also persist the raw embeddings of both papers.
# np.save('test_pub_gen_context_emb1.npy', np.array(emb_list).reshape(-1,768))
# np.save('test_pub_gen_context_emb2.npy', np.array(emb2_list).reshape(-1,768))
np.save('test_pub_gen_context_cossim.npy', np.array(cossim_list).reshape(-1,1))