In [1]:
import os
import torch
from transformers import AutoTokenizer

# If modeling.py lives under new_code/pretrain, place this script one level up,
# directly in new_code (next to the pretrain directory), and import it like this:
from pretrain.modeling import load_model_with_adapter


def build_task_config(chunk_size=500, mem_size=1, compress_ratio=500):
    """Assemble the task configuration dict that is passed to the model loader.

    The three size arguments should match the values used at training time
    (the defaults are chunk_size=500, mem_size=1, compress_ratio=500).

    Args:
        chunk_size: Stored under the "chunk_size" key.
        mem_size: Stored under the "mem_size" key.
        compress_ratio: Stored under the "compress_ratio" key.

    Returns:
        A freshly built dict holding the task type, the three size settings
        and a fixed set of boolean flags.
    """
    config = {"task_type": "Compress"}

    # Caller-tunable compression settings.
    config.update(
        {
            "chunk_size": chunk_size,
            "mem_size": mem_size,
            "compress_ratio": compress_ratio,
        }
    )

    # Fixed flags. Per the original author's notes these mainly matter during
    # training and are kept at sensible values for inference; "is_sft" is True
    # because the checkpoint is the post-SFT QA model, and "use_ae_loss" is said
    # to affect only forward(), not ae_inference/lm_inference (not verifiable
    # from this file — confirm against the modeling code).
    config.update(
        {
            "is_pretrain": False,
            "is_sft": True,
            "use_pe": True,
            "use_ae_loss": True,
            "use_lm_loss": True,
        }
    )
    return config


def load_compress_model(
    model_id,
    adapter_path,
    chunk_size=500,
    mem_size=1,
    compress_ratio=500,
    rank=0,
):
    """Load the compression model together with its adapter weights.

    Builds the task configuration from the size arguments, delegates the
    actual loading to ``load_model_with_adapter`` and switches the returned
    model to evaluation mode.

    Args:
        model_id: Forwarded to the loader as ``model_id``.
        adapter_path: Forwarded to the loader as ``save_path_and_name``.
        chunk_size: Passed through to ``build_task_config``.
        mem_size: Passed through to ``build_task_config``.
        compress_ratio: Passed through to ``build_task_config``.
        rank: Forwarded to the loader as ``rank``.

    Returns:
        A ``(model, tokenizer)`` tuple, where the tokenizer is the one exposed
        by the model as ``model.tokenizer``.
    """
    loader_kwargs = {
        "model_id": model_id,
        "task_config": build_task_config(chunk_size, mem_size, compress_ratio),
        "rank": rank,
        "save_path_and_name": adapter_path,
        "log": True,
    }

    print(f"Loading model [{model_id}] with adapter [{adapter_path}] ...")
    compress_model = load_model_with_adapter(**loader_kwargs)
    compress_model.eval()

    # Reuse the tokenizer attached to the model instead of loading a second one
    # (per the original note, the modeling code loads it from the same model_id).
    return compress_model, compress_model.tokenizer


def ae_reconstruct(model, tokenizer, context_text, max_new_tokens=512):
    """Reconstruct a long context through the model's ``ae_inference`` path.

    Args:
        model: Object exposing ``device`` and ``ae_inference(inputs)``.
        tokenizer: Callable tokenizer exposing ``bos_token_id`` and ``decode``.
        context_text: The original text to compress and reconstruct.
        max_new_tokens: Accepted for interface compatibility; this function
            does not currently read it.

    Returns:
        The text obtained by decoding whatever ``ae_inference`` returns, with
        special tokens skipped.
    """
    # Tokenize without special tokens and prepend BOS by hand; the original
    # note says this mirrors the pre-training data preparation.
    token_ids = tokenizer(context_text, add_special_tokens=False)["input_ids"]
    bos_id = tokenizer.bos_token_id
    if bos_id is not None:
        token_ids = [bos_id] + token_ids

    # Add a leading batch dimension of size 1 and move to the model's device.
    batch = torch.LongTensor(token_ids).unsqueeze(0).to(model.device)

    with torch.no_grad():
        # Per the original note, ae_inference returns a flat list of token ids.
        recon_ids = model.ae_inference({"input_ids": batch})

    return tokenizer.decode(recon_ids, skip_special_tokens=True)


def qa_inference(model, tokenizer, context_text, question, max_new_tokens=256):
    """Answer a question about a context through the model's ``lm_inference`` path.

    The context is sent as ``input_ids`` and the question prompt as
    ``lm_targets``; the model is expected to continue the prompt with the
    answer. The prompt template (per the original note, it mirrors
    instruction_prepare_data.py) is:

        "### Context:\\n" + context
        "\\n### Question:\\n" + question + "\\n### Answer:\\n"

    Args:
        model: Object exposing ``device`` and
            ``lm_inference(inputs, generate_num=...)``.
        tokenizer: Callable tokenizer exposing ``bos_token_id`` and ``decode``.
        context_text: The context passage.
        question: The question to ask about the context.
        max_new_tokens: Forwarded to ``lm_inference`` as ``generate_num``.

    Returns:
        The text obtained by decoding whatever ``lm_inference`` returns, with
        special tokens skipped.
    """

    def encode(text):
        # Special tokens are never added automatically; BOS is handled below.
        return tokenizer(text, add_special_tokens=False)["input_ids"]

    # Context part: [BOS] + "### Context:\n" + context.
    # Skip BOS when the tokenizer does not define one; otherwise a None would
    # end up in the id list and tensor construction would fail. This matches
    # the guard already used in ae_reconstruct.
    bos_id = tokenizer.bos_token_id
    context_ids = [] if bos_id is None else [bos_id]
    context_ids = context_ids + encode("### Context:\n") + encode(context_text)

    # Question prompt: "\n### Question:\n" + question + "\n### Answer:\n"
    question_ids = (
        encode("\n### Question:\n")
        + encode(question)
        + encode("\n### Answer:\n")
    )

    # Add a leading batch dimension of size 1 and move to the model's device.
    input_ids = torch.LongTensor(context_ids).unsqueeze(0).to(model.device)
    lm_targets = torch.LongTensor(question_ids).unsqueeze(0).to(model.device)

    inputs = {
        "input_ids": input_ids,      # the long context
        "lm_targets": lm_targets,    # the question prompt (answer not included)
    }

    with torch.no_grad():
        gen_ids = model.lm_inference(inputs, generate_num=max_new_tokens)

    return tokenizer.decode(gen_ids, skip_special_tokens=True)


def main():
    """Run the demo: load the model, then do one AE reconstruction and one QA query."""
    # ======= 1. Paths: adjust these to your own environment =======
    model_id = "/home/syt/project/Cram/model/model_scope_model/LLM-Research/Llama-3.2-1B-Instruct"
    work_dir = "/home/syt/project/compressor_500/new_code/experiment/llama32_1b_500to1"
    # Use adapter.pt if only pre-training was done; instruction_adapter.pt after SFT.
    adapter_path = os.path.join(work_dir, "output", "instruction_adapter.pt")

    # ======= 2. Load the model =======
    size_settings = {"chunk_size": 500, "mem_size": 1, "compress_ratio": 500}
    model, tokenizer = load_compress_model(
        model_id=model_id,
        adapter_path=adapter_path,
        rank=0,  # per the original note, this selects cuda:0
        **size_settings,
    )

    # The same passage serves as input for both demos below.
    passage = (
        "ICAE (Ge et al., 2024) is an autoencoder framework to compress long contexts into short compact memory slots. The method operates by concatenating designated memory tokens to the end of the input sequence before an encoder processes the entire combined sequence. Subsequently, a decoder reconstructs the original sequence using only the information contained within the memory tokens. ICAE is trained in two main phases. It is first pretrained on massive text data using a combination of autoencoding and language modeling objectives, enabling it to generate memory slots that represent the original context. Following pretraining, the model is fine-tuned on instruction data for the purpose of producing desirable responses to various prompts. An overview of the ICAE framework is shown in Figure 3."
    )

    # ======= 3. Example: AE reconstruction =======
    recon_text = ae_reconstruct(model, tokenizer, passage)
    # Header, reconstructed text, then two blank lines.
    print("=== AE Reconstruction ===", recon_text, "\n", sep="\n")

    # ======= 4. Example: QA inference =======
    question = "what is ICAE?"
    answer = qa_inference(model, tokenizer, passage, question, max_new_tokens=128)
    print("=== QA Answer ===", answer, sep="\n")


if __name__ == "__main__":
    main()


  from .autonotebook import tqdm as notebook_tqdm


ModuleNotFoundError: No module named 'pretrain.modeling'