In [None]:
from infer import get_model, infer, format_context

from transformers import AutoTokenizer
from transformers import PreTrainedTokenizer


import copy
import warnings

import torch
import torch.utils.checkpoint
import torch.nn.functional as F
from torch import nn
from torch.nn.utils import skip_init
from typing import Optional, List, Callable, Dict, Any

from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig

from transformers.utils import logging
# Module-level logger obtained from the transformers logging utilities.
logger = logging.get_logger(__name__)

# Checkpoint identifiers and switches for this notebook.
#   model_name        -- first argument handed to get_model / the tokenizer loader
#   peft_path         -- second argument handed to get_model (presumably a LoRA
#                        adapter id, judging by its name -- confirm in `infer`)
#   model_revision    -- revision pinned when loading the tokenizer
#   with_origin_model -- when True, also load the model without `peft_path`
model_name: str = "THUDM/chatglm-6b"
peft_path: str = "silk-road/luotuo-qa-lora-0.1"
model_revision: str = "969290547e761b20fdb96b0602b4fd8d863bbb85"
with_origin_model: bool = True

model = get_model(model_name, peft_path)
origin_model = get_model(model_name) if with_origin_model else None
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    revision=model_revision,
    trust_remote_code=True,
)

In [None]:
# Sample passage (Chinese text) passed as the `story` argument to
# question_answer_infer further down in this file.
story = '''
长妈妈曾经讲给我一个故事听：先前，有一个读书人住在古庙里用功，晚间， 在院子里纳凉的时候，突然听到有人在叫他。答应着，四面看时，却见一个美女的 脸露在墙头上，向他一笑，隐去了。他很高兴；但竟给那走来夜谈的老和尚识破了 机关。说他脸上有些妖气，一定遇见“美女蛇”了；这是人首蛇身的怪物，能唤人 名，倘一答应，夜间便要来吃这人的肉的。他自然吓得要死，而那老和尚却道无妨 ，给他一个小盒子，说只要放在枕边，便可高枕而卧。他虽然照样办，却总是睡不 着，——当然睡不着的。到半夜，果然来了，沙沙沙！门外象是风雨声。他正抖作 一团时，却听得豁的一声，一道金光从枕边飞出，外面便什么声音也没有了，那金 光也就飞回来，敛在盒子里。后来呢？后来，老和尚说，这是飞蜈蚣，它能吸蛇的 脑髓，美女蛇就被它治死了。
'''
# Question about the passage, passed as the `question` argument alongside `story`.
question = '是谁识破了机关？'


In [None]:

@torch.no_grad()
def continue_generate(
        model,
        input_ids: torch.Tensor,
        append_ids: torch.Tensor,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        **kwargs,
):
    """Generate a continuation of ``input_ids`` followed by ``append_ids``.

    The two tensors are concatenated along the last dimension and the model is
    then run token by token (sampling when ``generation_config.do_sample`` is
    set, arg-max otherwise) until every sequence has emitted an EOS token or
    the stopping criteria fire.

    Note that all length bookkeeping (``max_new_tokens`` -> ``max_length`` and
    the ``input_ids_seq_length`` handed to the logits processors) uses the
    length of ``input_ids`` alone; ``append_ids`` is not counted as prompt.

    Args:
        model: Object providing ``generation_config``, ``config``,
            ``prepare_inputs_for_generation``, ``_get_logits_processor``,
            ``_get_stopping_criteria``, ``_get_logits_warper`` and
            ``_update_model_kwargs_for_generation``.
        input_ids: Prompt token ids; the last dimension is the sequence.
        append_ids: Token ids appended to the prompt before generation starts.
        generation_config: Base configuration; defaults to
            ``model.generation_config``. It is deep-copied, never mutated.
        logits_processor: Extra logits processors (optional).
        stopping_criteria: Extra stopping criteria (optional).
        prefix_allowed_tokens_fn: Forwarded to ``model._get_logits_processor``.
        **kwargs: Applied to the copied generation config through its
            ``update`` method; whatever ``update`` returns is forwarded as
            model keyword arguments.

    Returns:
        The tensor of prompt, appended and generated token ids.
    """
    input_ids_seq_length = input_ids.shape[-1]

    if generation_config is None:
        generation_config = model.generation_config
    # Work on a copy so the caller's / model's config is left untouched.
    generation_config = copy.deepcopy(generation_config)
    model_kwargs = generation_config.update(**kwargs)

    # Normalise the EOS setting to a list; an unset EOS means that only the
    # stopping criteria can end generation.
    eos_token_id = generation_config.eos_token_id
    if eos_token_id is None:
        eos_token_ids = []
    elif isinstance(eos_token_id, int):
        eos_token_ids = [eos_token_id]
    else:
        eos_token_ids = list(eos_token_id)

    has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
    if has_default_max_length and generation_config.max_new_tokens is None:
        warnings.warn(
            f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
            "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
            " recommend using `max_new_tokens` to control the maximum length of the generation.",
            UserWarning,
        )
    elif generation_config.max_new_tokens is not None:
        # Warn before overwriting max_length so the message shows the value
        # the caller actually supplied.
        if not has_default_max_length:
            logger.warning(
                f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                "Please refer to the documentation for more information. "
                "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
            )
        generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length

    if input_ids_seq_length >= generation_config.max_length:
        input_ids_string = "decoder_input_ids" if model.config.is_encoder_decoder else "input_ids"
        logger.warning(
            f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
            f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
            " increasing `max_new_tokens`."
        )

    # Set generation parameters if not already defined.
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

    logits_processor = model._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids_seq_length,
        encoder_input_ids=input_ids,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        logits_processor=logits_processor,
    )

    stopping_criteria = model._get_stopping_criteria(
        generation_config=generation_config, stopping_criteria=stopping_criteria
    )
    logits_warper = model._get_logits_warper(generation_config)

    # Generation starts from the prompt with the forced continuation attached.
    input_ids = torch.cat([input_ids, append_ids], dim=-1)
    # 1 while a sequence is still generating, 0 once it has produced an EOS.
    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
    scores = None
    while True:
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
        # Forward pass to get the next-token logits.
        outputs = model(
            **model_inputs,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )

        next_token_logits = outputs.logits[:, -1, :]

        # Pre-process the distribution.
        next_token_scores = logits_processor(input_ids, next_token_logits)
        next_token_scores = logits_warper(input_ids, next_token_scores)

        # Sample (or take the most likely token).
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        if generation_config.do_sample:
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(probs, dim=-1)

        # Update generated ids and model inputs for the next step.
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        model_kwargs = model._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder
        )

        # A sequence is finished as soon as its new token matches ANY EOS id.
        if eos_token_ids:
            hit_eos = torch.zeros_like(next_tokens, dtype=torch.bool)
            for eos_id in eos_token_ids:
                hit_eos |= next_tokens == eos_id
            unfinished_sequences = unfinished_sequences.mul((~hit_eos).long())

        # Stop when every sequence is finished, or a stopping criterion fires.
        if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
            break
    return input_ids

def question_answer_infer(model, tokenizer: PreTrainedTokenizer, story, question, max_length=2048):
    """Answer ``question`` about ``story`` and print the result.

    The prompt built by ``format_context(story, question)`` and a fixed
    continuation prefix are encoded separately and handed to
    ``continue_generate``. Only the tokens after those two encoded parts are
    decoded as the answer. For comparison, the same text is also run through
    ``infer.gen`` and both results are printed.

    Args:
        model: Model forwarded to ``continue_generate`` and ``gen``; must
            expose ``device``.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        story: Context passage.
        question: Question about the passage.
        max_length: Forwarded to ``continue_generate`` as ``max_length``.

    Returns:
        The decoded answer with ``"\\nEND"`` removed and whitespace stripped.
    """
    append_text = f"""问题转义为:{question}
答案为:"""

    # Build the prompt once; it is reused for the comparison run below.
    prompt_text = format_context(story, question)

    input_token_ids = tokenizer.encode(prompt_text)
    input_ids = torch.LongTensor([input_token_ids]).to(model.device)
    append_token_ids = tokenizer.encode(append_text)
    append_ids = torch.LongTensor([append_token_ids]).to(model.device)
    out = continue_generate(
        model,
        input_ids,
        append_ids,
        max_length=max_length,
        do_sample=True,
        top_p=0.2,
        temperature=0.95,
        logits_processor=None,
    )[0]

    # Drop the prompt and the forced prefix; convert to plain ints for decode.
    prompt_len = len(input_token_ids) + len(append_token_ids)
    out_text = tokenizer.decode(out.tolist()[prompt_len:])
    answer = out_text.replace("\nEND", "").strip()
    print(f"question_answer_infer: ###{answer}###")

    from infer import gen
    gen_out = gen(model, tokenizer, prompt_text + append_text)
    print(f"default gen: ###{gen_out}###")

    return answer

# Run the QA demo ten times; each call prints its own results.
for attempt in range(10):
    question_answer_infer(model, tokenizer, story, question)