LLM模型准备

In [1]:
# Weights of the utility / support / relevance reflection scores in the
# combined ranking score computed by Model_LLM.eval_generation.
w_IsUSE = w_IsSUP = w_IsREL = 1
# Reflection-token switches. use_IsSUP / use_IsUSE are forwarded to
# load_special_tokens as use_grounding / use_utility.
# NOTE(review): use_IsREL is not read by the visible code.
use_IsUSE = use_IsSUP = use_IsREL = True
# Print per-candidate retrieval and scoring details while generating.
show_details = True

In [None]:
import sys
# Add self-rag's retrieval_lm directory to the import path so passage_retrieval can be imported.
sys.path.append("self-rag/retrieval_lm/")
from passage_retrieval import Retriever
import numpy as np

class Retriever_LLM:
    """Thin wrapper around self-rag's ``passage_retrieval.Retriever`` demo API.

    Args:
        top_k: default number of passages returned by ``search``; also passed
            as ``n_docs`` to ``setup_retriever_demo`` (previously this argument
            was accepted but ignored).
        model_path: Contriever checkpoint directory.
        passages_path: JSONL file holding the passages.
        embeddings_glob: glob matching the precomputed passage-embedding shards.
    """

    def __init__(self, top_k: int = 5,
                 model_path: str = "self-rag/retrieval_lm/contriever",
                 passages_path: str = "self-rag/retrieval_lm/enwiki_2020_intro_only/enwiki_2020_dec_intro_only.jsonl",
                 embeddings_glob: str = "self-rag/retrieval_lm/enwiki_2020_intro_only/enwiki_dec_2020_contriever_intro/*"):
        self.top_k = top_k
        self.retriever = Retriever({})
        # Loads the encoder and indexes the embedding shards (slow, does I/O).
        self.retriever.setup_retriever_demo(model_path, passages_path, embeddings_glob,
                                            n_docs=top_k, save_or_load_index=False)

    def search(self, query, n_docs=None):
        """Return the top passages for ``query``.

        Args:
            query: search string.
            n_docs: number of passages to return; ``None`` (default) uses the
                ``top_k`` given to the constructor.

        Returns:
            Whatever ``Retriever.search_document_demo`` returns; presumably a
            list of passage dicts with "title"/"text" keys — confirm against
            passage_retrieval.
        """
        if n_docs is None:
            n_docs = self.top_k
        return self.retriever.search_document_demo(query, n_docs)

# Shared retriever used by Model_LLM.generate; returns up to 5 passages by default.
retriever = Retriever_LLM(top_k=5)

Loading model from: self-rag/retrieval_lm/contriever


Some weights of the model checkpoint at self-rag/retrieval_lm/contriever were not used when initializing Contriever: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing Contriever from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Contriever from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Indexing passages from files ['self-rag/retrieval_lm/enwiki_2020_intro_only/enwiki_dec_2020_contriever_intro/passages_00', 'self-rag/retrieval_lm/enwiki_2020_intro_only/enwiki_dec_2020_contriever_intro/passages_01', 'self-rag/retrieval_lm/enwiki_2020_intro_only/enwiki_dec_2020_contriever_intro/passages_02', 'self-rag/retrieval_lm/enwiki_2020_intro_only/enwiki_dec_2020_contriever_intro/passages_03']
Loading file self-rag/retrieval_lm/enwiki_2020_intro_only/enwiki_dec_2020_contriever_intro/passages_00


In [None]:
class Response_LLM:
    """Plain record holding one model completion.

    Attributes:
        token_ids: generated token ids.
        text: decoded completion text.
        logprobs: one dict per generated position, mapping token id to a
            score. Despite the name, query_llm stores exp(logprob), i.e.
            probabilities, here.
    """

    def __init__(self, token_ids, text, logprobs):
        self.token_ids, self.text, self.logprobs = token_ids, text, logprobs
    

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from vllm import LLM, SamplingParams
from utils import load_special_tokens

class Model_LLM:
    """Self-RAG style generator built on a vLLM model.

    ``generate`` decodes greedily until the model emits the ``[Retrieval]``
    stop string. It then retrieves passages, re-queries the model once per
    passage, and keeps the continuation whose reflection tokens
    (IsSUP / IsREL / IsUSE) score highest. This repeats while the chosen
    continuation again asks for retrieval.
    """

    # Token id treated as "[Retrieval]" by need_retrieve.
    # NOTE(review): hard-coded for the self-rag vocabulary; confirm it equals
    # tokenizer.convert_tokens_to_ids("[Retrieval]").
    RETRIEVAL_TOKEN_ID = 32001

    # (scale, offset) per reflection token. A matching token contributes
    # prob * scale + offset to its group's sum. Other tokens of the same group
    # (e.g. "[Irrelevant]") count toward the average but contribute 0.
    _SUP_WEIGHTS = {
        "[Fully supported]": (0.5, 0.5),
        "[Partially supported]": (0.5, 0.0),
    }
    _REL_WEIGHTS = {
        "[Relevant]": (1.0, 0.0),
    }
    _USE_WEIGHTS = {
        "[Utility:1]": (0.2, 0.0),
        "[Utility:2]": (0.2, 0.2),
        "[Utility:3]": (0.2, 0.4),
        "[Utility:4]": (0.2, 0.6),
        "[Utility:5]": (0.2, 0.8),
    }

    def __init__(self, local_model_path: str, max_tokens: int, skip_special_tokens: bool, logprobs: int, tokenizer):
        """Load the vLLM model and the reflection-token id tables.

        Args:
            local_model_path: model path passed to ``vllm.LLM``.
            max_tokens: generation budget per model call.
            skip_special_tokens: forwarded to ``SamplingParams``.
            logprobs: number of top log-probabilities returned per position.
            tokenizer: the model's tokenizer, used to map ids back to tokens.
        """
        self.model = LLM(model=local_model_path, dtype="half")
        # Greedy decoding (temperature 0); stop as soon as retrieval is requested.
        self.sampling_params = SamplingParams(
            temperature=0.0,
            top_p=1.0,
            top_k=-1,
            max_tokens=max_tokens,
            skip_special_tokens=skip_special_tokens,
            logprobs=logprobs,
            stop=["[Retrieval]"],
        )
        self.ret_tokens, self.rel_tokens, self.sup_tokens, self.use_tokens = load_special_tokens(
            tokenizer, use_grounding=use_IsSUP, use_utility=use_IsUSE)
        self.tokenizer = tokenizer

    def need_retrieve(self, response: Response_LLM):
        """Return True if the response's token ids contain the retrieval token.

        Args:
            response: model output.

        Returns:
            bool: whether retrieval is needed.
        """
        return self.RETRIEVAL_TOKEN_ID in response.token_ids

    def format_prompt(self, input, paragraph=None):
        """Build a Self-RAG instruction prompt.

        Args:
            input: instruction / question text.
            paragraph: optional passage text, appended after a
                "[Retrieval]" marker inside <paragraph> tags.

        Returns:
            str: the formatted prompt.
        """
        prompt = "### Instruction:\n{0}\n\n### Response:\n".format(input)
        if paragraph is not None:
            prompt += "[Retrieval]<paragraph>{0}</paragraph>".format(paragraph)
        return prompt

    @staticmethod
    def _document_text(document):
        """Return the passage text to embed in a prompt.

        Retrieved passages are indexed as doc["title"] / doc["text"] elsewhere
        in this notebook, so dicts are rendered as "title\\ntext" instead of
        their repr. Any other value is returned unchanged.
        """
        if isinstance(document, dict) and "text" in document:
            return f"{document.get('title', '')}\n{document['text']}"
        return document

    def query_llm(self, query):
        """Run one greedy generation for ``query``.

        Args:
            query: full prompt string.

        Returns:
            Response_LLM whose ``logprobs`` holds, per generated position, a
            dict of token id -> exp(logprob) (i.e. probabilities).
        """
        output = self.model.generate([query], self.sampling_params)[0].outputs[0]
        # NOTE(review): float(value) assumes vLLM returns plain floats here;
        # newer vLLM versions return Logprob objects (use value.logprob).
        probs = [
            {token_id: np.exp(float(value)) for token_id, value in step.items()}
            for step in output.logprobs
        ]
        return Response_LLM(output.token_ids, output.text, probs)

    def re_query(self, prompt, documents):
        """Generate one continuation per retrieved passage and keep the best.

        Args:
            prompt: the original question plus previously generated text.
            documents: passages returned by the retriever.

        Returns:
            Response_LLM with the highest eval_generation score (ties keep the
            earliest), or None only when ``documents`` is empty.
        """
        # Start below any real score so a candidate is always chosen; starting
        # at 0 returned None whenever every candidate scored 0.
        best_score = float("-inf")
        best_response = None
        for document in documents:
            response = self.query_llm(self.format_prompt(prompt, self._document_text(document)))
            cur_score = self.eval_generation(response)
            if show_details:
                print("此次检索的结果如下")
                print(document)
                print(cur_score)
                print(response.token_ids)
                print(response.text)
            if cur_score > best_score:
                best_score = cur_score
                best_response = response
        return best_response

    def _average_token_score(self, response, special_tokens, weights):
        """Average the reflection score of one token group over the response.

        Args:
            response: model output.
            special_tokens: token -> id table for the group, or None when the
                group is disabled.
            weights: token string -> (scale, offset).

        Returns:
            float: mean contribution over group tokens found, 0.0 if none.
        """
        if special_tokens is None:
            return 0.0
        group_ids = set(special_tokens.values())  # built once, not per token
        total = 0.0
        count = 0
        for tok_idx, tok in enumerate(response.token_ids):
            if tok not in group_ids:
                continue
            token = self.tokenizer.convert_ids_to_tokens(tok)
            if token in weights:
                scale, offset = weights[token]
                total += response.logprobs[tok_idx][tok] * scale + offset
            count += 1
        return total / count if count else 0.0

    def eval_generation(self, response: Response_LLM):
        """Score a response from its reflection tokens.

        Args:
            response: model output.

        Returns:
            float: w_IsUSE * utility + w_IsSUP * support + w_IsREL * relevance.
        """
        sup_score = self._average_token_score(response, self.sup_tokens, self._SUP_WEIGHTS)
        rel_score = self._average_token_score(response, self.rel_tokens, self._REL_WEIGHTS)
        use_score = self._average_token_score(response, self.use_tokens, self._USE_WEIGHTS)
        if show_details:
            print(f"u:{use_score}, s:{sup_score}, r:{rel_score}")
        score = w_IsUSE * use_score + w_IsSUP * sup_score + w_IsREL * rel_score
        return score

    def generate(self, query, n_docs: int = 3, max_retrievals: int = 10):
        """Answer ``query``, retrieving passages whenever the model asks.

        Args:
            query: user question.
            n_docs: passages retrieved per retrieval round.
            max_retrievals: upper bound on retrieval rounds, so a model that
                keeps requesting retrieval cannot loop forever; None disables
                the bound.

        Returns:
            str: concatenated generated text.
        """
        response = self.query_llm(self.format_prompt(query))
        result = response.text
        rounds = 0
        while self.need_retrieve(response) and (max_retrievals is None or rounds < max_retrievals):
            context = query + "\n" + result
            documents = retriever.search(context, n_docs)
            response = self.re_query(context, documents)
            if response is None:
                # Nothing was retrieved, so there is no continuation to append.
                break
            result += response.text
            rounds += 1
        return result


   
            
        
    
        

# Self-RAG checkpoint directory, used for both the tokenizer and the vLLM model.
local_model_path = "self-rag/retrieval_lm/self-rag-model"
tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side="left")
# 100 new tokens per call, keep special (reflection) tokens in the text,
# and return the top-10 log-probabilities per position.
model = Model_LLM(
    local_model_path,
    max_tokens=100,
    skip_special_tokens=False,
    logprobs=10,
    tokenizer=tokenizer,
)

# model = LLM(model=local_model_path, dtype="half")




In [None]:
# def format_prompt(input, paragraph=None):
#     prompt = "### Instruction:\n{0}\n\n### Response:\n".format(input)
#     if paragraph is not None:
#         prompt += "[Retrieval]<paragraph>{0}</paragraph>".format(paragraph)
#     return prompt

In [None]:
# tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side="left")
# sampling_params = SamplingParams(temperature=0.0, 
#                                       top_p=1.0, top_k = -1, max_tokens=100, skip_special_tokens=False, 
#                                       logprobs = 10, stop = ["[Retrieval]"])
# ret_tokens, rel_tokens, sup_tokens, use_tokens = load_special_tokens(tokenizer, use_grounding=use_IsSUP, use_utility=use_IsUSE)

In [None]:
# def need_retrieve(response: Response_LLM):
#     if 32001 in response.token_ids:
#         return True
#     return False
# # prompt:输入：list<str>
# def query_llm(model, sampling_params, query):
#     prompt = [query]
#     print(f"查询内容{prompt}")
#     preds = model.generate(prompt, sampling_params)
#     pred_token_ids = preds[0].outputs[0].token_ids
#     pred_text = preds[0].outputs[0].text
#     pred_log_probs = []
#     for logprob in preds[0].outputs[0].logprobs:
#         tmp_log_probs = {}
#         for key, value in logprob.items():
#             tmp_log_probs[key] = np.exp(float(value))
#         pred_log_probs.append(tmp_log_probs)
#     response =  Response_LLM(pred_token_ids, pred_text, pred_log_probs)
#     return response
# # params
# # query:之前查询的问题+之前生成的文本（去除特殊字符 TODO）
# def re_query(model, sampling_params, prompt, documents):
#     max_score = 0
#     best_response = None
#     for document in documents:
#         response = query_llm(model, sampling_params, format_prompt(prompt, document))
#         cur_score = eval_generation(response)
#         print("此次检索的结果如下")
#         print(cur_score)
#         print(response.token_ids)
#         print(response.text)
#         if (cur_score > max_score):
#             max_score = cur_score
#             best_response = response
#     return best_response

# def eval_generation(response:Response_LLM):
#     sup_score = 0.0
#     if sup_tokens is not None:
#         num = 0
#         for tok_idx, tok in enumerate(response.token_ids):
#             if tok in list(sup_tokens.values()):
#                 token = tokenizer.convert_ids_to_tokens(tok)
#                 if token == "[Fully supported]":
#                     sup_score += response.logprobs[tok_idx][tok] * 0.5 + 0.5
#                 if token == "[Partially supported]":
#                     sup_score += response.logprobs[tok_idx][tok] * 0.5
#                 num += 1
#         if num != 0:
#             sup_score /= num
#         else:
#             sup_score = 0.0
#     else:
#         sup_score = 0.0
#     rel_score = 0.0
#     if rel_tokens is not None:
#         num = 0
#         for tok_idx, tok in enumerate(response.token_ids):
#             if tok in list(rel_tokens.values()):
#                 token = tokenizer.convert_ids_to_tokens(tok)
#                 if token == "[Relevant]":
#                     rel_score += response.logprobs[tok_idx][tok]
#                 num += 1
#         if num != 0:
#             rel_score /= num
#         else:
#             rel_score = 0.0
#     else:
#         rel_score = 0.0
#     use_score = 0.0
#     if use_tokens is not None:
#         num = 0
#         for tok_idx, tok in enumerate(response.token_ids):
#             if tok in list(use_tokens.values()):
#                 token = tokenizer.convert_ids_to_tokens(tok)
#                 if token == "[Utility:1]":
#                     use_score += response.logprobs[tok_idx][tok] * 0.2
#                 elif token == "[Utility:2]":
#                     use_score += response.logprobs[tok_idx][tok] * 0.2 + 0.2
#                 elif token == "[Utility:3]":
#                     use_score += response.logprobs[tok_idx][tok] * 0.2 + 0.4
#                 elif token == "[Utility:4]":
#                     use_score += response.logprobs[tok_idx][tok] * 0.2 + 0.6
#                 elif token == "[Utility:5]":
#                     use_score += response.logprobs[tok_idx][tok] * 0.2 + 0.8
#                 num += 1
#         if num != 0:
#             use_score /= num
#         else:
#             use_score = 0.0
#     else:
#         use_score = 0.0
#     print(f"u:{use_score}, s:{sup_score}, r:{rel_score}")
#     score = w_IsUSE * use_score + w_IsSUP * sup_score + w_IsREL * rel_score
#     return score
# def generate(model, sampling_params, query):
#     tmp_response = query_llm(model, sampling_params, format_prompt(query))
#     result = tmp_response.text
#     while(need_retrieve(tmp_response)):
#         documents = retriever.search(query + "\n" + result, 3)
#         tmp_response = re_query(model, sampling_params, query + "\n" + result, documents)
#         result += tmp_response.text
#     return result

In [None]:
# Sample question used to exercise the retrieve-and-rerank loop below.
query = "Can you tell me the difference between llamas and alpacas?"

In [None]:
# Run the full generation loop; Model_LLM.generate returns the concatenated text (a str).
response = model.generate(query)

In [None]:
print(response)  # final generated answer text

In [None]:
print(tokenizer)  # inspect the tokenizer configuration (incl. added special tokens)

In [None]:
# prompt = query + response.text

导入检索器（Retriever）

In [None]:
# from passage_retrieval import Retriever

# class Retriever_LLM:
#     def __init__(self, top_k:int = 5):
#         self.retriever = Retriever({})
#         self.retriever.setup_retriever_demo("self-rag/retrieval_lm/contriever", "self-rag/retrieval_lm/enwiki_2020_intro_only/enwiki_2020_dec_intro_only.jsonl", "self-rag/retrieval_lm/enwiki_2020_intro_only/enwiki_dec_2020_contriever_intro/*",  n_docs=5, save_or_load_index=False)
#     def search(self, query, n_docs:int = 5):
#         retrieved_documents = self.retriever.search_document_demo(query, n_docs)
#         return retrieved_documents

# retriever = Retriever_LLM(5)

将检索结果加入 prompt

In [None]:
# def append_prompt(prompt:str, docs):
#     prompts = [format_prompt(prompt, doc["title"] +"\n"+ doc["text"]) for doc in docs]
#     return prompts
    

In [None]:
# documents = retriever.search(query, 10)

In [None]:
# prompts = append_prompt(response.text, documents)

In [None]:
# print(prompts)

In [None]:
# prompts = [format_prompt(query_3, doc["title"] +"\n"+ doc["text"]) for doc in retrieved_documents]
# preds = model.generate(prompts, sampling_params)
# top_doc = retriever.search_document_demo(query_3, 1)[0]
# print("Reference: {0}\nModel prediction: {1}".format(top_doc["title"] + "\n" + top_doc["text"], preds[0].outputs[0].text))

In [None]:
# def format_tokens(response):
#     token_dict = {}
    
#     for i, key in enumerate(response.token_ids):
#         key = tokenizer.convert_ids_to_tokens(id)
#         logprobs = sorted(response.logprobs[0], reverse=True)
#         token_dict[key] = response

In [None]:
# import numpy as np
# def _relevance_score(pred_log_probs) -> float:
#     rel_prob = np.exp(float(pred_log_probs["[Relevant]"]))
#     irel_prob = np.exp(float(pred_log_probs["[Irrelevant]"]))
#     return rel_prob / (rel_prob + irel_prob)

In [None]:
# # 转化为id
# text = '[Retrieval]'
# token_ids = tokenizer.encode(text, max_length = 30, add_special_tokens = True, padding = 'max_length', truncation = True)
# # [101, 791, 1921, 3221, 702, 1962, 1921, 3698, 8024, 2769, 812, 1377, 809, 1139, 1343, 6624, 6624, 511, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
 
# tokened_text = tokenizer.convert_ids_to_tokens(32000)
# # ['[CLS]', '今', '天', '是', '个', '好', '天', '气', '，', '我', '们', '可', '以', '出', '去', '走', '走', '。', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']
 

In [None]:
# print(token_ids[-1])

In [None]:
# tokened_text = tokenizer.convert_ids_to_tokens(32000)
# print(tokened_text)