In [1]:
import torch
import argparse
import time
import math
import os, sys
import itertools
import pickle
import torch
from tqdm import tqdm


In [2]:
# Report how many CUDA devices PyTorch can see in this kernel.
torch.cuda.device_count()

1

In [3]:
data_path = '/share/data/mei-work/kangrui/github/ref-sum/refsum/data/refsum-data/arxiv-aiml-small'

# Load the dev-split query ids.
# NOTE: pickle.load can execute arbitrary code stored in the file; only load trusted artifacts.
# The `with` block closes the file on exit, so no explicit close() is needed.
with open(os.path.join(data_path, 'dev.pkl'), 'rb') as f:
    dev_idx = pickle.load(f)

print(len(dev_idx))

1000


In [4]:
data_path = '/share/data/mei-work/kangrui/github/ref-sum/refsum/data/refsum-data/arxiv-aiml-small'

# Load the BM25 retrieval results for the dev queries (a mapping keyed by query id).
# NOTE: pickle.load can execute arbitrary code stored in the file; only load trusted artifacts.
# The `with` block closes the file on exit, so no explicit close() is needed.
with open(os.path.join(data_path, 'top200_list_dev_BM25_correct.pkl'), 'rb') as f:
    dev_data = pickle.load(f)

print(len(dev_data))

1000


In [5]:
# Count the dev query ids that have an entry in the BM25 retrieval results.
total = sum(1 for k in dev_idx if k in dev_data)
print('query coverage:', total, total/len(dev_idx))


query coverage: 1000 1.0


In [None]:
full_path="/share/data/mei-work/kangrui/github/ref-sum/refsum/data/refsum-data/arxiv-aiml/full_data_no_embed.pkl"
# Load the paper table (paper id -> record) used to look up abstracts.
# NOTE: pickle.load can execute arbitrary code stored in the file; only load trusted artifacts.
# The `with` block closes the file on exit, so no explicit close() is needed.
with open(full_path, 'rb') as f:
    id2paper = pickle.load(f)

In [None]:
# Compute how much of the dev set (queries and their retrieved candidates) is
# present in id2paper, and drop queries whose retrieval entry is unusable.
print('computing coverage ...')

# Query coverage: dev queries that have a record in id2paper.
total = sum(1 for k in dev_data if k in id2paper)
print('query coverage:', total, total/len(dev_data) if dev_data else 0.0)


# Candidate coverage: unique candidate ids per query that have a record in id2paper.
total = 0
covered = 0
total_cand = 0
empty_key=[]
for k, v in dev_data.items():
    all_ref_set = set()
    try:
        for sim_idx, sim_ref_list in v['top100']:
            for paper_id in sim_ref_list:
                all_ref_set.add(paper_id)
                total_cand += 1
    except (KeyError, TypeError, ValueError):
        # Missing or malformed 'top100' entry: remember the query so it can be removed.
        # Only data-shape errors are caught; anything else (e.g. KeyboardInterrupt) propagates.
        empty_key.append(k)
        continue
    covered += sum(1 for paper_id in all_ref_set if paper_id in id2paper)
    total += len(all_ref_set)

# Deletion happens after the loop because a dict cannot be resized while iterating it.
for k in empty_key:
    del dev_data[k]
if empty_key:
    print('dropped queries without usable top100:', len(empty_key))
# print(total_cand, total_cand/len(dev_data))
print('candidate coverage:', covered, total, covered/total if total else 0.0)

In [None]:
import torch.nn as nn
import torch.nn.functional as F
# from torch.utils.data import DataLoader
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.autograd import Variable

from transformers import BertModel, BertTokenizer, AutoConfig
from transformers import LongformerConfig, LongformerModel, LongformerTokenizer
import logging

from transformers import AutoTokenizer, AutoModel

In [None]:
class NXTENTRerankDataset(Dataset):
    """Tokenizes a query paper and a list of candidate papers for re-ranking.

    ``id2paper`` maps a paper id to a record with an ``'abstract'`` entry. Ids that
    are missing from the mapping, or whose abstract is ``None``, are encoded as the
    empty string. Items are fetched with :meth:`get_item` rather than by index.
    """

    def __init__(self, id2paper, plm_name=None, cache_dir=None):
        """
        Args:
            id2paper: mapping from paper id to a record holding ``'abstract'``.
            plm_name: model name or path passed to ``AutoTokenizer.from_pretrained``.
            cache_dir: optional tokenizer cache directory; when given, the tokenizer
                is loaded from it with ``local_files_only=True``.

        Raises:
            ValueError: if ``plm_name`` is ``None``.
        """
        if plm_name is None:
            # Without this check the tokenizer load hit an unbound local variable.
            raise ValueError('plm_name is required to load the tokenizer')

        self.data = id2paper

        # init tokenizer
        if cache_dir is None:
            self.tokenizer = AutoTokenizer.from_pretrained(plm_name)
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(plm_name, cache_dir=cache_dir, local_files_only=True)
        self.max_len = 500

    def __len__(self):
        """Return the number of papers available for lookup."""
        return len(self.data)

    def _abstract_text(self, paper_id):
        """Return the stripped, lower-cased abstract of ``paper_id`` ('' if unavailable)."""
        if paper_id not in self.data:
            return ''
        abstract = self.data[paper_id]['abstract']
        if abstract is None:
            return ''
        return abstract.strip().lower()

    def _encode(self, text):
        """Tokenize ``text``, padded/truncated to exactly ``self.max_len`` tokens."""
        return self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=True,
            truncation=True,
        )

    def get_item(self, query_id, candidate_ids):
        """Encode one query and its candidates.

        Args:
            query_id: id of the query paper.
            candidate_ids: sequence of candidate paper ids (order is preserved).

        Returns:
            dict with
              ``'query'``: ``{'ids', 'mask', 'token_type_ids'}``, each a long tensor
              of shape ``(max_len,)``;
              ``'key'``: ``{'ids', 'mask'}``, each a long tensor of shape
              ``(len(candidate_ids), max_len)``.
        """
        query_inputs = self._encode(self._abstract_text(query_id))
        query = {
            'ids': torch.tensor(query_inputs['input_ids'], dtype=torch.long),
            'mask': torch.tensor(query_inputs['attention_mask'], dtype=torch.long),
            'token_type_ids': torch.tensor(query_inputs["token_type_ids"], dtype=torch.long),
        }

        ks = []
        kms = []
        for key_id in candidate_ids:
            key_inputs = self._encode(self._abstract_text(key_id))
            ks.append(key_inputs['input_ids'])
            kms.append(key_inputs['attention_mask'])

        # reshape keeps the (N, max_len) layout even for N == 0, where torch.tensor([])
        # alone would produce a 1-D tensor; for N > 0 it is a no-op.
        key = {
            'ids': torch.tensor(ks, dtype=torch.long).reshape(-1, self.max_len),
            'mask': torch.tensor(kms, dtype=torch.long).reshape(-1, self.max_len),
        }

        return {
            'key': key,
            'query': query,
        }

In [None]:
class BertSiameseClassifier(nn.Module):
    """Siamese scorer built on a single shared BERT encoder.

    Both inputs go through the same encoder. Each sequence is pooled into one
    vector by a learned token-level attention, and every source vector is scored
    against every target vector with a dot product.
    """

    def __init__(self, args, max_length, bert_model=None, prefix_tuning=True,
                 fine_tuning=True, blank_padding=True):
        """
        Args:
            args: configuration object; only ``args.hidden_size`` is read here.
            max_length: stored as ``self.max_length``.
            bert_model: encoder to use; when ``None`` the ``bert-base-uncased``
                checkpoint is loaded.
            prefix_tuning: accepted but overridden; the stored flag is always False.
            fine_tuning: stored as ``self.fine_tuning``.
            blank_padding: stored as ``self.blank_padding``.
        """
        super().__init__()

        self.args = args
        self.max_length = max_length
        self.blank_padding = blank_padding
        self.hidden_size = args.hidden_size

        self.attentive_bert = False
        self.device = 'cuda'
        self.method_name = 'direct_siamese'

        # Whatever the caller passed, prefix tuning is recorded as disabled.
        self.prefix_tuning = False
        self.fine_tuning = fine_tuning

        print('Prefix-tuning:', self.prefix_tuning)
        print('Fine-tuning:', self.fine_tuning)

        # One scalar attention logit per token (used as the pooling head).
        self.attention_fc = nn.Linear(self.hidden_size, 1, bias=False)

        if bert_model is not None:
            self.bert = bert_model
        else:
            logging.info('Loading BERT pre-trained checkpoint.')
            self.bert = BertModel.from_pretrained("bert-base-uncased")

        self.bert.gradient_checkpointing_enable()

        self.loss = nn.CrossEntropyLoss()

    def pred_vars(self):
        """Return the predictor's parameters: the encoder's, then the attention head's."""
        return [*self.bert.parameters(), *self.attention_fc.parameters()]

    def _attention_pool(self, states, mask):
        """Pool token states (B, L, H) into (B, H) with masked softmax attention."""
        logits = self.attention_fc(states).squeeze(-1)   # (B, L)
        # Positions where mask == 0 receive a large negative logit, i.e. ~0 weight.
        logits = logits + (1. - mask) * -1e9
        weights = F.softmax(logits, dim=-1)              # (B, L)
        return torch.sum(states * weights.unsqueeze(-1), dim=1)

    def forward(self, ss, sms, ts, tms, use_context=True):
        """Score every source sequence against every target sequence.

        Args:
            ss: (Bs, Ls) source token ids.
            sms: (Bs, Ls) source attention mask, passed to the encoder as
                ``attention_mask`` and used to mask the pooling attention.
            ts: (Bt, Lt) target token ids.
            tms: (Bt, Lt) target attention mask, used like ``sms``.
            use_context: unused.

        Returns:
            (Bs, Bt) tensor of dot-product scores, or ``None`` when
            ``self.method_name`` is not ``'direct_siamese'``.
        """
        src_states = torch.tanh(self.bert(ss, attention_mask=sms)[0])   # (Bs, Ls, H)
        tgt_states = torch.tanh(self.bert(ts, attention_mask=tms)[0])   # (Bt, Lt, H)

        if self.method_name != 'direct_siamese':
            return None

        src_vec = self._attention_pool(src_states, sms)   # (Bs, H)
        tgt_vec = self._attention_pool(tgt_states, tms)   # (Bt, H)
        return torch.mm(src_vec, tgt_vec.t())

In [None]:
from transformers import BertModel
from torch import nn
import torch


class DoubleBERT(nn.Module):
    """Two-tower scorer with separate BERT encoders for queries and keys.

    Both towers start from the same pretrained checkpoint but hold independent
    weights. Scores are dot products between the encoders' second output.
    """

    def __init__(self,PLM_NAME):
        """
        Args:
            PLM_NAME: checkpoint name or path passed to ``BertModel.from_pretrained``
                for both towers.
        """
        super().__init__()
        self.query_bert = BertModel.from_pretrained(PLM_NAME)
        self.key_bert = BertModel.from_pretrained(PLM_NAME)

        # Enable gradient checkpointing on both towers.
        for encoder in (self.key_bert, self.query_bert):
            encoder.gradient_checkpointing_enable()

    def forward(self, ss, sms, ts, tms, use_context=True):
        """Score every query against every key.

        Args:
            ss: (Bq, L) query token ids.
            sms: (Bq, L) query attention mask.
            ts: (Bk, L) key token ids.
            tms: (Bk, L) key attention mask.
            use_context: unused.

        Returns:
            (Bq, Bk) tensor of dot-product scores.
        """
        # Element 1 of the encoder output is used as the sequence vector; for a
        # Hugging Face BertModel that is the pooled output, shape (B, H).
        query_vec = self.query_bert(ss, attention_mask=sms)[1]
        key_vec = self.key_bert(ts, attention_mask=tms)[1]

        return torch.mm(query_vec, key_vec.t())
    
#     def forward(self, input,mode=None):
#         if mode!=None: # eval mode
#             if mode=='key':
#                 key_out=self.key_bert(input['ids'], attention_mask = input['mask'], token_type_ids = input['token_type_ids'])
#                 return key_out[1]
#             elif mode=='query':
#                 query_out=self.query_bert(input['ids'], attention_mask = input['mask'], token_type_ids = input['token_type_ids'])
#                 return query_out[1]
#         key_out=self.key_bert(input['key']['ids'], attention_mask = input['key']['mask'], token_type_ids = input['key']['token_type_ids'])
#         query_out=self.query_bert(input['query']['ids'], attention_mask = input['query']['mask'], token_type_ids = input['query']['token_type_ids'])
#         key_pooler_output=key_out[1]
#         query_pooler_output=query_out[1]

#         key_embedding=key_pooler_output
#         query_embedding=query_pooler_output
#         return {'key':key_embedding,'query':query_embedding}

In [None]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Pretrained checkpoint name, shared by the model encoders and the dataset tokenizer.
PLM_NAME = 'bert-base-uncased'

In [None]:
# Build the two-tower model from the pretrained checkpoint and move it to `device`.
model = DoubleBERT(PLM_NAME).to(device)

In [None]:
import os.path as osp
from collections import OrderedDict
import torch
import torch.nn
from collections import OrderedDict
from torch.nn.parallel import DistributedDataParallel as DDP
def load_network(model,network_pth_path,device,loaded_net=None,name=None):
    """Load a checkpoint's state dict into ``model`` and return the model.

    Keys saved from a ``DataParallel``/``DistributedDataParallel`` wrapper carry a
    ``"module."`` prefix; that prefix is stripped so the keys match a bare model.

    Args:
        model: module that receives the weights (modified in place).
        network_pth_path: checkpoint path; read only when ``loaded_net`` is ``None``.
        device: device spec used as ``map_location`` when reading the checkpoint.
        loaded_net: optional, already-loaded state dict; skips reading from disk.
        name: unused; kept for interface compatibility.

    Returns:
        The same ``model`` object.

    Warns:
        UserWarning: when the checkpoint and the model disagree on parameter names.
            Loading is non-strict, so such keys are skipped rather than raising;
            the warning makes a partially (or not at all) loaded model visible.
    """
    import warnings

    if loaded_net is None:
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        loaded_net = torch.load(
            network_pth_path,
            map_location=torch.device(device),
        )

    prefix = "module."
    loaded_clean_net = OrderedDict()  # remove unnecessary 'module.'
    for k, v in loaded_net.items():
        clean_key = k[len(prefix):] if k.startswith(prefix) else k
        loaded_clean_net[clean_key] = v

    result = model.load_state_dict(loaded_clean_net, strict=False)
    if result.missing_keys or result.unexpected_keys:
        warnings.warn(
            "load_network: checkpoint does not fully match the model "
            f"(missing keys: {list(result.missing_keys)}, "
            f"unexpected keys: {list(result.unexpected_keys)})"
        )
    return model

In [None]:
# Replace the pretrained weights with the fine-tuned checkpoint.
# load_network loads non-strictly, so parameter names that do not match do not raise.
network_pth_path="/share/data/mei-work/kangrui/github/ref-sum/refsum/testmulti/bert/2023-03-14T00-07-12/checkpoints/best.pt"
model=load_network(model,network_pth_path,device)

In [None]:
# Tokenizing wrapper over id2paper; run() calls its get_item() for each query and its candidates.
val_dataset=NXTENTRerankDataset(id2paper, cache_dir=None, plm_name=PLM_NAME)

In [None]:
def run(RANK_SET_SIZE):
    num_overlap = 0
    num_predict = 0
    num_gt = 0
    print(f"RANK_SET_SIZE: {RANK_SET_SIZE}")
    for k, v in dev_data.items():
    #     print(k, v)
        gt_ref_set = v['gt_ref']
    #     print(gt_ref)

        all_ref_set = set()
        ref_freq_dict = {}
        total = 0
        for sim_idx, sim_ref_list in v['top100']:
    #         all_ref_set.add(sim_idx)
            for paper_id in sim_ref_list:
                all_ref_set.add(paper_id)
                if paper_id not in ref_freq_dict:
                    ref_freq_dict[paper_id] = 1
                else:
                    ref_freq_dict[paper_id] += 1
    #         total += len(sim_ref_list)
            total += 1

        sorted_ref = sorted(ref_freq_dict.items(), key=lambda x:x[1], reverse=True)
    #     sorted_ref_set = {k for k,v in sorted_ref[:len(gt_ref_set)]}
        sorted_ref_set = {k for k,v in sorted_ref[:RANK_SET_SIZE]}
    #     sorted_ref_set = {k for k in list(all_ref_set)[:2048]}

    #     overlap_set = all_ref_set.intersection(gt_ref_set)
        overlap_set = sorted_ref_set.intersection(gt_ref_set)
    #     print(len(overlap_set))

        num_gt += len(gt_ref_set)
    #     num_predict += len(all_ref_set)
        num_predict += len(sorted_ref_set)
        num_overlap += len(overlap_set)

    #     break

    print('average number of gt ref:', num_gt / len(dev_data))
    print('average number of predicted ref:', num_predict / len(dev_data))

    prec, rec = num_overlap / num_predict, num_overlap / num_gt
    print('precision: {:.4f} recall: {:.4f} f1: {:.4f}'.format(prec, rec, 2 * prec * rec/(prec+rec)))
    
    num_overlap = 0
    num_predict = 0
    num_gt = 0

    pbar = tqdm(dev_data.items(), postfix=f"Testing")

    # all_data_tuples = []

    model.eval()
    with torch.no_grad():
        for query_id, v in pbar:
            gt_ref_set = v['gt_ref']

            all_ref_set = set()
            ref_freq_dict = {}
            total = 0
            for sim_idx, sim_ref_list in v['top100']:
        #         all_ref_set.add(sim_idx)
                for paper_id in sim_ref_list:
                    all_ref_set.add(paper_id)
                    if paper_id not in ref_freq_dict:
                        ref_freq_dict[paper_id] = 1
                    else:
                        ref_freq_dict[paper_id] += 1
        #         total += len(sim_ref_list)
                total += 1

            sorted_ref = sorted(ref_freq_dict.items(), key=lambda x:x[1], reverse=True)
    #         sorted_ref_set = {k for k,v in sorted_ref[:RANK_SET_SIZE]}
    #         sorted_ref_list = list(sorted_ref_set)
            sorted_ref_list = [k for k,v in sorted_ref[:RANK_SET_SIZE]]
    #         sorted_ref_set = set(sorted_ref_list)

            data = val_dataset.get_item(query_id, sorted_ref_list)

            ss = data["query"]['ids'].to(device).unsqueeze(0)
            sms = data["query"]['mask'].to(device).unsqueeze(0)
            ts = data["key"]['ids'].to(device)
            tms = data["key"]['mask'].to(device)

    #         all_data_tuples.append((ss, sms, ts, tms, query_id, sorted_ref_list))

            pred_logits = model(ss, sms, ts, tms)

            min_k = min(len(gt_ref_set), len(sorted_ref_list))
            topk = torch.topk(pred_logits[0], k=min_k)[1]
            pred_set = {sorted_ref_list[idx.cpu().item()] for idx in topk}

    #         pred_set = {sorted_ref_list[idx] for idx in range(min_k)}
    #         pred_set = {sorted_ref_list[idx] for idx in range(len(gt_ref_set))}
    #         pred_set = set(sorted_ref_list[:len(gt_ref_set)])
    #         print(pred_set)
    #         pred_set = {k for k,v in sorted_ref[:len(gt_ref_set)]}
    #         print(pred_set)
    #         print(set(list(pred_set)))

            overlap_set = pred_set.intersection(gt_ref_set)

            num_gt += len(gt_ref_set)
            num_predict += len(pred_set)
            num_overlap += len(overlap_set)

            prec, rec = num_overlap / num_predict, num_overlap / num_gt
            f1 = 2 * prec * rec/(prec+rec)

            pbar.postfix = "Testing: prec-{:.4f} rec-{:.4f} f1-{:.4f}".format(prec, rec, f1)

    #         break


    print('average number of gt ref:', num_gt / len(dev_data))
    print('average number of predicted ref:', num_predict / len(dev_data))

    prec, rec = num_overlap / num_predict, num_overlap / num_gt
    print('precision: {:.4f} recall: {:.4f} f1: {:.4f}'.format(prec, rec, 2 * prec * rec/(prec+rec)))

    