# FunASR-GGUF Pickle Inference Notebook

这个 Notebook 用于从保存的 Embedding Pickle 文件中直接进行推理。

In [None]:
import pickle
import numpy as np
import logging
import ctypes
import time
import os
import sys

# 尝试添加路径以防找不到模块
sys.path.append(os.getcwd())

from llama_cpp import (
    Llama,
    llama_batch_init,
    llama_batch_free,
    llama_decode,
    llama_get_logits,
    llama_kv_self_clear,
)

# 配置日志
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    force=True  # 强制重置日志配置，适配 Notebook 环境
)
logger = logging.getLogger(__name__)

In [None]:
# 常量定义
GGUF_MODEL_PATH = r'./model-gguf/qwen3-0.6b-asr.gguf'
MAX_SEQ_LEN = 1024
STOP_TOKEN = [151643, 151645]
MAX_THREADS = 0 # 0 = Auto

In [10]:
def decode_with_pure_embeddings(llm_obj, audio_embeddings, max_new_tokens=200):
    """
    纯 Embedding 解码函数 (复用自 Python 脚本)
    """
    
    # 1. 准备数据
    embeds = audio_embeddings.squeeze()
    if len(embeds.shape) == 1:
        embeds = embeds.reshape(1, -1)
    
    n_tokens, n_dim = embeds.shape
    logger.info(f"注入 Embedding Shape: {embeds.shape}")

    # 2. 初始化 Batch
    batch_embd = llama_batch_init(n_tokens, n_dim, 1)        
    batch_text = llama_batch_init(1, 0, 1)

    ctx = llm_obj.ctx
    
    # 3. 清理上下文缓存
    llama_kv_self_clear(llm_obj.ctx) 
    
    try:
        # ---------------------------------------------------------------------
        # A. 注入 Embedding
        # ---------------------------------------------------------------------
        logger.info("正在注入 Embedding...")
        
        batch_embd.n_tokens = n_tokens
        llm_obj.n_tokens = 0 
        
        # 关键：batch.token 设置为 NULL
        batch_embd.token = ctypes.cast(None, ctypes.POINTER(ctypes.c_int32))

        for i in range(n_tokens):
            batch_embd.pos[i] = i
            batch_embd.n_seq_id[i] = 1
            batch_embd.seq_id[i][0] = 0
            batch_embd.logits[i] = 1 if i == n_tokens - 1 else 0

        if not embeds.flags['C_CONTIGUOUS']:
            embeds = np.ascontiguousarray(embeds)
        
        ctypes.memmove(batch_embd.embd, embeds.ctypes.data, embeds.nbytes)
        
        if llama_decode(ctx, batch_embd) != 0:
             raise RuntimeError("Audio embedding decoding failed")
        
        llm_obj.n_tokens += n_tokens

        # ---------------------------------------------------------------------
        # B. 文本生成
        # ---------------------------------------------------------------------
        generated_text = ""
        logger.info(f"开始生成文本...\n")
        
        eos_token = llm_obj.token_eos()
        vocab_size = llm_obj.n_vocab()
        
        batch_text.n_tokens = 1
        
        gen_start_time = time.time()
        tokens_generated = 0
        
        for step in range(max_new_tokens):
            logits_ptr = llama_get_logits(ctx)
            logits_arr = np.ctypeslib.as_array(logits_ptr, shape=(vocab_size,))
            token_id = int(np.argmax(logits_arr))
            
            if token_id == eos_token or token_id in STOP_TOKEN:
                break
                
            try:
                text_piece = llm_obj.detokenize([token_id]).decode('utf-8', errors='ignore')
                print(text_piece, end="", flush=True)
                generated_text += text_piece
                tokens_generated += 1
            except Exception:
                pass
                
            batch_text.token[0] = token_id
            batch_text.pos[0] = llm_obj.n_tokens
            batch_text.n_seq_id[0] = 1
            batch_text.seq_id[0][0] = 0
            batch_text.logits[0] = 1
            
            if llama_decode(ctx, batch_text) != 0:
                break
            
            llm_obj.n_tokens += 1
            
        print('\n\n\n')
        gen_duration = time.time() - gen_start_time
        tps = tokens_generated / gen_duration if gen_duration > 0 else 0
        logger.info(f"解码速度: {tps:.2f} tokens/s ({tokens_generated} tokens in {gen_duration:.2f}s)\n")
        
    finally:
        llama_batch_free(batch_embd)
        llama_batch_free(batch_text)

    return generated_text

In [None]:
# 在这里加载模型 (只需运行一次))
print(f'Loading GGUF model: {GGUF_MODEL_PATH}')
llm = Llama(
    model_path=GGUF_MODEL_PATH,
    n_ctx=MAX_SEQ_LEN + 1024,
    n_threads=MAX_THREADS,
    embedding=True,
    verbose=False
)
print('GGUF model loaded successfully!')

In [11]:
# === 指定要转录的 Pickle 文件 ===
# 修改下面的路径为你想要测试的文件
TARGET_PICKLE = r'./pickles/embedding_slice_0_160000.pkl' 

# 自动查找最新 (可选)
if not os.path.exists(TARGET_PICKLE) and os.path.exists("pickles"):
    files = [os.path.join("pickles", f) for f in os.listdir("pickles") if f.endswith(".pkl")]
    if files:
        TARGET_PICKLE = max(files, key=os.path.getctime)
        print(f"Auto-selected latest file: {TARGET_PICKLE}")

print(f"Processing: {TARGET_PICKLE}")

if os.path.exists(TARGET_PICKLE):
    with open(TARGET_PICKLE, 'rb') as f:
        embeddings_data = pickle.load(f)
    
    print(f"Loaded embeddings shape: {embeddings_data.shape}")
    print("\n--- Result ---")
    result = decode_with_pure_embeddings(llm, embeddings_data, max_new_tokens=MAX_SEQ_LEN)
    print("\n--- End ---")
else:
    print(f"Error: File {TARGET_PICKLE} not found.")

2026-01-19 17:40:46,746 - INFO - 注入 Embedding Shape: (155, 1024)
2026-01-19 17:40:46,759 - INFO - 正在注入 Embedding...
init: embeddings required but some input tokens were not marked as outputs -> overriding


Processing: ./pickles/embedding_slice_0_160000.pkl
Loaded embeddings shape: (155, 1024)

--- Result ---


2026-01-19 17:40:47,409 - INFO - 开始生成文本...



，星期日，欢迎收看一千零四期誓言消息，请静静介绍话题。去年十月十九日，九百六十七期节目说到委内瑞拉问题，我们回顾一下你当时的评。





2026-01-19 17:40:48,543 - INFO - 解码速度: 39.69 tokens/s (45 tokens in 1.13s)




--- End ---
