+ results: 存放最终预测结果文件夹
+ data: 赛题数据和其他数据文件夹
+ models: huggingface模型权重文件夹

In [1]:
!mkdir -p results
!mkdir -p data
!mkdidr -p models

In [1]:
import json
import faiss
import torch
from torch import nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel, AutoTokenizer
from tqdm import tqdm
import numpy as np

def read_json(data_path: str) -> dict:
    """Load and return the JSON document stored at ``data_path`` (UTF-8)."""
    with open(data_path, encoding="utf-8") as handle:
        payload = json.load(handle)
    return payload
        
def read_txt(data_path: str) -> list:
    """Read a JSON-lines file and return one parsed object per non-blank line.

    Each non-blank line is parsed with ``json.loads``. Blank or
    whitespace-only lines, such as an extra empty line at the end of the
    file, are skipped instead of making ``json.loads`` raise.

    Args:
        data_path: path of the UTF-8 encoded file.

    Returns:
        List of the decoded objects, in file order.
    """
    samples = []
    with open(data_path, "r", encoding="utf-8") as fh:
        for line in fh:
            line = line.strip()
            if line:
                samples.append(json.loads(line))
    return samples

def write_txt(samples, data_path: str) -> None:
    """Write each string in ``samples`` to ``data_path`` on its own line.

    The file is opened in ``"w"`` mode (UTF-8), so any existing content is
    replaced. Every line, including the last one, is terminated by a newline.

    Args:
        samples: iterable of strings, one per output line.
        data_path: destination file path.
    """
    with open(data_path, "w", encoding="utf-8") as fh:
        for sample in samples:
            # write() takes the whole line; writelines() on a str would
            # iterate it character by character.
            fh.write(sample + '\n')
            
def load_jsonl(file):
    """Parse a JSON-lines file into a list of objects.

    Lines are stripped first. Blank or whitespace-only lines are skipped, and
    every other line is decoded with ``json.loads``.

    Args:
        file: path of the UTF-8 encoded JSON-lines file.

    Returns:
        List of the decoded objects, in file order.
    """
    samples = []
    with open(file, "r", encoding="utf-8") as fh:
        for line in fh:
            line = line.strip()
            if line:
                samples.append(json.loads(line))
    return samples
    
def head(data, n=5):
    """Print up to ``n`` randomly chosen records of ``data`` for a quick look.

    Keys are drawn without replacement, so no record is printed twice and at
    most ``len(data)`` records are shown. An empty mapping prints nothing.
    For each record, the record's keys, the record itself and a separator
    line are printed.

    Args:
        data: mapping whose values support ``.keys()`` (dict-like records).
        n: number of records to show.
    """
    keys = list(data.keys())
    n = min(n, len(keys))
    if n <= 0:
        return
    # Sample positions rather than keys so the original key objects are used
    # for the lookup (np.random.choice would convert keys to NumPy scalars).
    for idx in np.random.choice(len(keys), n, replace=False):
        record = data[keys[idx]]
        print(record.keys())
        print(record)
        print('=' * 50)

In [2]:
# Paper corpus: mapping of pid -> paper record (title / abstract / ...).
data = read_json('data/AQA/AQA-test-public/pid_to_title_abs_update_filter.json')

# Question splits, each stored as one JSON object per line.
train_data, val_data, test_data = (
    read_txt('data/AQA/qa_train.txt'),
    read_txt('data/AQA/qa_valid_wo_ans.txt'),
    read_txt('data/AQA/AQA-test-public/qa_test_wo_ans_new.txt'),
)

In [4]:
model_name = 'NV-Embed-v1'
model_path = f'models/{model_name}'
device = torch.device('cuda:0')

# Load the config first so the nested text config can be pointed at the
# local checkpoint directory (presumably so the remote code resolves the
# text backbone from disk rather than the Hub -- confirm).
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
config.text_config._name_or_path = model_path

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(
    model_path,
    config=config,
    trust_remote_code=True,
    torch_dtype=torch.float16,
).to(device)
model.eval()  # inference only

print(f'#parameters: {sum(p.numel() for p in model.parameters()) / 1e6}M')

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

#parameters: 7851.016192M


# 使用额外的DBLP数据库提取关键词

In [5]:
# Index the DBLP citation-network dump by paper id so that papers in `data`
# can be enriched with their DBLP keywords.
database = load_jsonl('data/DBLP-Citation-network-V15.json')

database_dict = {}
for d in database:
    database_dict[d.pop('id')] = d

# Paper ids of `data` that were found in DBLP. Initialised once, outside the
# loops, so it accumulates every match instead of being reset per record.
hit = []
for i in tqdm(data):
    record = database_dict.get(i)
    if record is None:
        data[i]['keywords'] = ''
        continue
    hit.append(i)
    # Guard against records whose keyword list is missing or null.
    data[i]['keywords'] = ','.join(record.get('keywords') or [])

100%|██████████| 466387/466387 [00:00<00:00, 789858.09it/s]


# 推理检索

In [None]:
def get_detailed_instruct(task_description: str, query: str) -> str:
    """Wrap ``query`` in the ``Instruct: ...`` / ``Query: ...`` prompt template."""
    template = 'Instruct: {}\nQuery: {}'
    return template.format(task_description, query)
    
def get_paper_string(pid, papers=None):
    """Build the passage text that is embedded for paper ``pid``.

    Title, abstract and keywords are joined with a single space, so the last
    word of one field is not fused with the first word of the next. Fields
    that are missing, ``None`` or empty are skipped.

    Args:
        pid: key of the paper in ``papers``.
        papers: mapping pid -> record dict with optional ``title``,
            ``abstract`` and ``keywords`` entries. Defaults to the
            module-level ``data`` corpus.

    Returns:
        The passage string, which is empty if every field is missing.
    """
    if papers is None:
        papers = data
    record = papers[pid]
    parts = (record.get('title'), record.get('abstract'), record.get('keywords'))
    return ' '.join(part for part in parts if part)

def get_question_string(d):
    """Build the instruction-prefixed query text for one question record.

    ``question`` and ``body`` are joined with a single space, skipping fields
    that are missing, ``None`` or empty, so the two do not run together. The
    result is then wrapped by ``get_detailed_instruct`` with the retrieval
    task description.

    Args:
        d: question record (dict) with optional ``question`` / ``body`` keys.

    Returns:
        The formatted query string.
    """
    parts = (d.get('question'), d.get('body'))
    text = ' '.join(part for part in parts if part)
    # Each query must come with a one-sentence instruction that describes the task
    task = 'Given a question, retrieve passages that answer the question'
    return get_detailed_instruct(task, text)

# Passage side: one text per paper. pids[j] is the paper behind row j of the
# corpus embeddings built from data_query.
data_query, pids = [], []
for paper_id in tqdm(data):
    data_query.append(get_paper_string(paper_id))
    pids.append(paper_id)
pids = np.array(pids)

# Query side for the validation and test questions.
val_question_query = [get_question_string(sample) for sample in tqdm(val_data)]
test_question_query = [get_question_string(sample) for sample in tqdm(test_data)]

# Training: the question text is repeated once per gold paper, so
# train_question_query[j] pairs with train_pids[j].
train_question_query, train_pids = [], []
for sample in tqdm(train_data):
    text = get_question_string(sample)
    for gold_pid in sample['pids']:
        train_question_query.append(text)
        train_pids.append(gold_pid)
train_pids = np.array(train_pids)

In [7]:
def mean_pooling(token_embeddings, attention_mask):
    """Average token embeddings over positions where ``attention_mask`` is set.

    The mask is broadcast to the embedding shape and cast to float. The
    divisor is clamped to 1e-9 so that a fully masked row does not divide
    by zero.
    """
    weights = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = (token_embeddings * weights).sum(dim=1)
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts

def last_token_pool(last_hidden_states, attention_mask):
    """Return the hidden state of each row's last non-padded token.

    If the final mask column is set for every row, the batch is treated as
    left-padded and the final position is taken. Otherwise the index
    ``attention_mask.sum(dim=1) - 1`` is gathered per row.
    """
    n_rows = last_hidden_states.shape[0]
    if attention_mask[:, -1].sum() != attention_mask.shape[0]:
        last_idx = attention_mask.sum(dim=1) - 1
        row_ids = torch.arange(n_rows, device=last_hidden_states.device)
        return last_hidden_states[row_ids, last_idx]
    return last_hidden_states[:, -1]


@torch.no_grad()
def get_embedding(corpus, tokenizer, model, device, instruction, max_length=4096, normalize=True):
    """Embed one batch of texts with ``model.encode`` and return float32 vectors.

    Args:
        corpus: list of texts forming one batch.
        tokenizer: unused here; kept so existing call sites keep working.
        model: object exposing ``encode(texts, instruction=..., max_length=...)``
            (the NV-Embed remote-code model in this notebook).
        device: unused here; kept so existing call sites keep working.
        instruction: forwarded to ``model.encode``.
        max_length: forwarded to ``model.encode``.
        normalize: L2-normalise each embedding (the default). The vectors are
            searched with an inner-product index, so unit-length vectors make
            the scores cosine similarities instead of favouring large norms.

    Returns:
        float32 tensor, presumably of shape (len(corpus), dim) -- confirm
        against ``model.encode``.
    """
    e = model.encode(corpus, instruction=instruction, max_length=max_length)
    e = e.float()  # upcast first so normalisation is not done in fp16
    if normalize:
        e = F.normalize(e, p=2, dim=-1)
    return e

def generate_embedding(texts, tokenizer, model, device, instruction='', batch_size=32):
    """Embed ``texts`` in consecutive slices of ``batch_size`` and stack the results.

    Each slice goes through ``get_embedding``. The per-slice outputs are
    moved to CPU and concatenated along the first dimension in input order.
    """
    starts = range(0, len(texts), batch_size)
    batches = [texts[s: s + batch_size] for s in starts]
    chunks = []
    for chunk_texts in tqdm(batches, desc="Generating embedding"):
        vectors = get_embedding(chunk_texts, tokenizer, model, device, instruction)
        chunks.append(vectors.cpu())
    return torch.cat(chunks)

def drop_dup(x, k=20):
    """Return the distinct items of ``x`` in first-seen order, stopping at ``k`` items.

    Iteration stops as soon as ``k`` distinct items have been collected.
    """
    seen = set()
    unique = []
    for item in x:
        if item in seen:
            continue
        seen.add(item)
        unique.append(item)
        if len(unique) == k:
            break
    return unique

In [None]:
# Encode the paper corpus and every question split with an empty extra
# instruction. The very small batch size is presumably chosen to fit the
# ~7.8B-parameter fp16 model in GPU memory -- confirm.
batch_size = 2
embeddings_corpus = generate_embedding(data_query, tokenizer, model, device, '', batch_size)
embeddings_train = generate_embedding(train_question_query, tokenizer, model, device, '', batch_size)
embeddings_val = generate_embedding(val_question_query, tokenizer, model, device, '', batch_size)
embeddings_test = generate_embedding(test_question_query, tokenizer, model, device, '', batch_size)

  'input_ids': torch.tensor(batch_dict.get('input_ids').to(batch_dict.get('input_ids')).long()),
Generating embedding:   2%|▏         | 4361/233194 [09:47<10:46:49,  5.90it/s]

In [None]:
def build_index(x):
    dimension = x.shape[-1]
    # index = faiss.IndexFlatL2(dimension)
    index = faiss.IndexFlatIP(dimension)
    index.add(x.float().cpu())
    return index
    
def search(index, x, k=20):
    distance, topk = index.search(x.float().cpu(), k)
    return distance, topk

In [None]:
# The corpus index is the same for every stage, so build it once.
corpus_index = build_index(embeddings_corpus)

for stage in ['val', 'test']:
    print(stage)
    embedding = embeddings_val if stage == 'val' else embeddings_test

    # Retrieve 40 candidates per question, then map row indices to paper ids.
    distance, topk = search(corpus_index, embedding, 40)
    topk = pids[topk]

    # Keep the first 20 distinct pids of each ranked row.
    results = [drop_dup(r) for r in topk]
    for row, r in enumerate(results):
        # drop_dup returns distinct items, so len(r) is the unique count.
        if len(r) != 20:
            raise ValueError(f'{stage} query {row}: expected 20 unique pids, got {len(r)}')

    write_txt([','.join(r) for r in results], f'results/{stage}_result_{model_name}.txt')
    torch.save({'distance': distance, 'topk': topk},
               f'results/{stage}_result_{model_name}.pt')


In [None]:
print('#parameters: {}'.format(sum(p.numel() for p in model.parameters())))