In [7]:
# Switch for the optional text-cleaning step: when True the question bodies are
# passed through clean_body_remove_symbol before encoding, and the embeddings
# are written to the file name without the "-NoClean" marker.
clean = False

# 使用 SFR-Embedding-Mistral 进行 embedding 编码 

模型的总参数量: 7110660096

In [8]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import gc 
import ctypes
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel

import json
import pandas as pd 
import re
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import numpy as np 
import pickle

def clean_memory():
    """Release as much host and GPU memory as possible.

    Runs the Python garbage collector, asks glibc to trim its heap via
    ``malloc_trim(0)``, and empties the CUDA caching allocator.

    The glibc step is best-effort: on platforms where ``libc.so.6`` cannot be
    loaded, or does not export ``malloc_trim``, it is skipped instead of
    raising, so the helper can be called on any host.
    """
    gc.collect()
    try:
        ctypes.CDLL("libc.so.6").malloc_trim(0)
    except (OSError, AttributeError):
        # No glibc here (e.g. macOS / Windows / musl): nothing to trim.
        pass
    torch.cuda.empty_cache()

In [9]:
# Number of texts the DataLoader groups into one batch during embedding inference.
BATCH_SIZE = 32

## question embedding 

In [10]:
def last_token_pool(last_hidden_states: Tensor,
                    attention_mask: Tensor) -> Tensor:
    """Select, for every sequence, the hidden state of its last attended token.

    When the final column of ``attention_mask`` is set for every row (the batch
    is left-padded), the last position is returned for all rows. Otherwise the
    position is derived per row from the number of attended tokens.
    """
    n_rows = attention_mask.shape[0]
    # Every row attends to the final position -> padding sits on the left.
    if attention_mask[:, -1].sum() == n_rows:
        return last_hidden_states[:, -1]

    # Right padding: the last real token is at index (attended count - 1).
    last_positions = attention_mask.sum(dim=1) - 1
    row_index = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
    return last_hidden_states[row_index, last_positions]

def get_detailed_instruct(task_description: str, query: str) -> str:
    """Prefix a query with its task instruction, separated by a newline."""
    return 'Instruct: {}\nQuery: {}'.format(task_description, query)


def encode_questions_SFR_Embedding_Mistral(
    df,
    model_path='/mntdata/wangql43/A000Files/A003Model/recallModel/SFR-Embedding-Mistral/',
    max_length=2048,
    batch_size=None,
):
    """Encode question rows into L2-normalised SFR-Embedding-Mistral vectors.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide ``question`` and ``body`` columns. Missing values are
        replaced by empty strings before the prompts are built.
    model_path : str, optional
        Location handed to ``AutoTokenizer.from_pretrained`` and
        ``AutoModel.from_pretrained``. Defaults to the previously hard-coded path.
    max_length : int, optional
        Truncation length passed to the tokenizer.
    batch_size : int or None, optional
        Texts per batch; ``None`` falls back to the notebook-level ``BATCH_SIZE``.

    Returns
    -------
    numpy.ndarray
        2-D float16 array with one embedding per row of ``df``, in row order.

    Raises
    ------
    RuntimeError
        From ``torch.cat`` when ``df`` has no rows (nothing was encoded).
    """
    df = df.fillna('')
    if batch_size is None:
        batch_size = BATCH_SIZE

    ## model
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModel.from_pretrained(model_path)
    model.eval()
    model.to('cuda')

    ## encode
    # Each query must come with a one-sentence instruction that describes the task.
    # Alternative prompts that were tried:
    # task = 'Given Question and The Detailed Analysis of the Question, retrieve most relevant Title and Abstract that answer the Question.'
    # task = 'Retrieve the most pertinent title and abstract addressing the user’s question and its analysis, emphasizing understanding, accuracy, and relevance.'
    task = 'Please retrieve and provide the most relevant title and abstract of literature based on the user’s specific question and its detailed analysis. Ensure a deep understanding of the question’s underlying meaning, and prioritize the accuracy and relevance of the information in the returned results.'

    ## df -- Query
    queries = [
        get_detailed_instruct(
            task,
            f"Question: {question}\n\n The Detailed Analysis of the Question: {body}",
        )
        for question, body in zip(df['question'], df['body'])
    ]

    ## dl
    def _tokenize(batch_texts):
        # Pad to the longest text of the batch and truncate at max_length.
        return tokenizer(batch_texts, max_length=max_length, padding=True,
                         truncation=True, return_tensors='pt')

    dataloader = DataLoader(
        queries, batch_size=batch_size, num_workers=16,
        collate_fn=_tokenize,
    )

    ## inference
    embeddings = []
    with torch.no_grad(), torch.autocast(device_type='cuda'):
        for batch in tqdm(dataloader):
            batch = batch.to(model.device)
            model_output = model(**batch)
            sentence_embeddings = last_token_pool(model_output.last_hidden_state, batch['attention_mask'])
            sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
            # Move each result off the GPU right away so finished batches do
            # not accumulate in device memory.
            embeddings.append(sentence_embeddings.cpu())
    embeddings = torch.cat(embeddings, dim=0).numpy().astype(np.float16)

    # Drop the remaining references to the model and the last batch's GPU
    # tensors; while they are alive empty_cache() cannot release their memory.
    # torch.cat above raises on an empty list, so the loop has run at least
    # once by the time we get here and all of these names exist.
    del model, batch, model_output, sentence_embeddings
    clean_memory()

    return embeddings

In [11]:
## 读取数据的内容 
def read_train_valid_test(path):
    """Load a JSON-lines ``.txt`` file into a DataFrame.

    Every non-blank line of the file is parsed as one JSON object and becomes
    one row of the result.

    Parameters
    ----------
    path : str or os.PathLike
        File to read; its name must end in ``.txt``.

    Returns
    -------
    pandas.DataFrame
        One row per JSON record, in file order.

    Raises
    ------
    AssertionError
        If ``path`` does not end in ``.txt``.
    json.JSONDecodeError
        If a non-blank line is not valid JSON.
    """
    assert str(path).endswith('.txt'), f'expected a .txt file, got {path!r}'

    records = []
    # Read as UTF-8 explicitly instead of relying on the platform default encoding.
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing empty line at end of file).
            if not line.strip():
                continue
            records.append(json.loads(line))

    return pd.DataFrame(records)

def read_json_to_df(json_path):
    """Load a JSON object of records into a DataFrame with a ``pids`` column.

    The file is expected to hold a mapping whose keys become the ``pids``
    column and whose values (dicts) supply the remaining columns.

    Parameters
    ----------
    json_path : str or os.PathLike
        JSON file to read.

    Returns
    -------
    pandas.DataFrame
        One row per top-level key, with that key stored in ``pids``.
    """
    # Read as UTF-8 explicitly instead of relying on the platform default encoding.
    with open(json_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    # Top-level keys start out as columns; transpose so they index the rows,
    # then move that index into a regular column.
    data = pd.DataFrame(data).T.reset_index(names=['pids'])
    return data


## 清洗数据 
def clean_body_remove_symbol(text):
    """Strip markup tags, newlines, repeated whitespace and URL fragments.

    Steps, in order: replace every ``<...>`` tag with a space, delete newline
    characters, trim both ends, collapse whitespace runs to a single space,
    then remove the substrings ``http://``, ``https://``, ``.com`` and ``.cn``.
    """
    without_tags = re.sub('<[^<]+?>', ' ', text)
    # NOTE(review): newlines are removed outright, not replaced by a space, so
    # words separated only by a line break end up joined.
    flattened = without_tags.replace('\n', '').strip()
    cleaned = re.sub(r'\s+', ' ', flattened)
    for fragment in ('http://', 'https://', '.com', '.cn'):
        cleaned = cleaned.replace(fragment, '')
    return cleaned

In [12]:
## 
testpath = 'data/AQA-test-public/qa_test_wo_ans_new.txt'
test = read_train_valid_test(testpath)

# Space-delimited token count of "question + body" per row. The summary is
# printed explicitly: it is not the last expression of the cell, so as a bare
# expression it would be computed and silently discarded.
text_lengths = (test['question'] + '\n' + test['body']).apply(lambda x : len(x.split(' ')))
print(text_lengths.describe())

## Optionally clean the body text before encoding
if clean: 
    test['body'] = test['body'].apply(clean_body_remove_symbol)

In [14]:
clean_memory()

# The cleaned and uncleaned runs do the same work; only the output file differs.
output_path = ('outslgb/encoded_question_srf_mistral_test.pkl' if clean
               else 'outslgb/encoded_question_srf_mistral-NoClean_test.pkl')
# Create the output directory up front so a missing folder cannot fail the
# write after the expensive encoding has already finished.
os.makedirs(os.path.dirname(output_path), exist_ok=True)

srf_mistral_embeddings = encode_questions_SFR_Embedding_Mistral(test)
with open(output_path, 'wb') as f:
    pickle.dump(srf_mistral_embeddings, f)
clean_memory()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/94 [00:01<?, ?it/s]

In [1]:
# Printed last so a completed run is easy to spot in the cell output.
print('Finish !!!')

Finish !!!
