In [None]:
import torch
from models.miniTransformer import generate_square_subsequent_mask, TransformerChat
import myTokenizer

def greedy_decode(model, tokenizer, input_text, max_output_len=150, device='cpu', verbose=True):
    """Generate a reply for ``input_text`` with greedy (arg-max) decoding.

    Args:
        model: Module exposing ``src_embedding``, ``tgt_embedding``,
            ``pos_encoder``, ``transformer.encoder``, ``transformer.decoder``
            and ``fc_out``. The decoder output is indexed as
            ``output[:, -1, :]``, i.e. a batch-first layout is assumed.
        tokenizer: Object exposing ``texts_to_sequences``, ``word_index`` and
            ``index_word``.
        input_text: Raw user text; it is passed through
            ``myTokenizer.clean_text`` before tokenization.
        max_output_len: Maximum number of tokens to generate after ``<start>``.
        device: Device on which the input and decoder tensors are created.
        verbose: When True (default), print the cleaned text, the encoded
            input and the ``<start>`` id, as the function always used to do.

    Returns:
        The generated tokens (without ``<start>`` / ``<end>``) joined by
        single spaces; ids missing from ``index_word`` become ``'<UNK>'``.

    Raises:
        ValueError: If the cleaned input tokenizes to an empty sequence.

    Note:
        If the tokenizer has no ``<start>`` / ``<end>`` entry, ids 1 and 2 are
        used as silent fallbacks. NOTE(review): those ids may belong to
        unrelated tokens in a given vocabulary -- confirm this is intended.
    """
    model.eval()

    # Step 1: clean and encode the input text.
    cleaned = myTokenizer.clean_text(input_text)
    if verbose:
        print(f"cleaned: {cleaned}")
    input_seq = tokenizer.texts_to_sequences([cleaned])
    if verbose:
        print(f"input_seq: {input_seq}")
    if len(input_seq) == 0 or len(input_seq[0]) == 0:
        raise ValueError("input_text produced no tokens after cleaning/tokenization.")
    input_tensor = torch.tensor(input_seq, dtype=torch.long, device=device)

    # Step 2: the decoder input starts with the <start> token.
    start_token_id = tokenizer.word_index.get('<start>', 1)
    if verbose:
        print(f"start_token_id: {start_token_id}")
    end_token_id = tokenizer.word_index.get('<end>', 2)
    decoder_input = torch.tensor([[start_token_id]], device=device)

    # Inference only: keep every forward pass (embeddings included) out of autograd.
    with torch.no_grad():
        # Step 3: encode the source once; it is reused at every decoding step.
        src_emb = model.pos_encoder(model.src_embedding(input_tensor))
        memory = model.transformer.encoder(src_emb)

        # Step 4: generate one token per iteration.
        for _ in range(max_output_len):
            tgt_emb = model.pos_encoder(model.tgt_embedding(decoder_input))
            tgt_mask = generate_square_subsequent_mask(decoder_input.size(1)).to(device)

            output = model.transformer.decoder(tgt_emb, memory, tgt_mask=tgt_mask)
            logits = model.fc_out(output[:, -1, :])  # logits of the last position

            # Penalize emitting <start> again right after <start> (avoids a loop).
            if decoder_input[0, -1].item() == start_token_id:
                logits[0, start_token_id] -= 1.0

            next_token = logits.argmax(dim=-1).unsqueeze(0)
            decoder_input = torch.cat([decoder_input, next_token], dim=1)

            if next_token.item() == end_token_id:
                break

    # Step 5: map ids back to words. Index the batch row explicitly instead of
    # squeeze(): squeeze() collapses a one-token sequence to a scalar.
    output_tokens = decoder_input[0, 1:].tolist()  # drop the leading <start>
    words = [tokenizer.index_word.get(tok, '<UNK>') for tok in output_tokens if tok != end_token_id]
    return ' '.join(words)


In [7]:
def beam_search_decode(model, tokenizer, input_text, beam_width=5, max_output_len=100, device='cuda'):
    """Generate a reply for ``input_text`` with beam search.

    Each beam is a ``(token tensor of shape (1, T), cumulative log-prob)``
    pair. Beams whose last token is ``<end>`` are carried over unchanged, and
    the search stops early once every surviving beam has ended.

    Args:
        model: Module exposing ``src_embedding``, ``tgt_embedding``,
            ``pos_encoder``, ``transformer.encoder``, ``transformer.decoder``
            and ``fc_out``. The decoder output is indexed as
            ``output[:, -1, :]``, i.e. a batch-first layout is assumed.
        tokenizer: Object exposing ``texts_to_sequences``, ``word_index`` and
            ``index_word``.
        input_text: Raw user text; it is passed through
            ``myTokenizer.clean_text`` before tokenization.
        beam_width: Number of beams kept after each step. Values larger than
            the output vocabulary are clamped when expanding a beam.
        max_output_len: Maximum number of expansion steps.
        device: Device on which the input and beam tensors are created.

    Returns:
        The highest-scoring sequence as space-joined words, with every
        ``<start>`` and ``<end>`` id removed; ids missing from ``index_word``
        become ``'<UNK>'``.

    Raises:
        ValueError: If ``beam_width`` is below 1, if the tokenizer lacks
            ``<start>`` or ``<end>``, or if the input tokenizes to nothing.
    """
    if beam_width < 1:
        raise ValueError("beam_width must be at least 1.")

    model.eval()

    start_token = tokenizer.word_index.get('<start>')
    end_token = tokenizer.word_index.get('<end>')
    if start_token is None or end_token is None:
        raise ValueError("Tokenizer must contain <start> and <end> tokens.")

    # Step 1: clean and encode the input text.
    cleaned = myTokenizer.clean_text(input_text)
    input_seq = tokenizer.texts_to_sequences([cleaned])
    if len(input_seq) == 0 or len(input_seq[0]) == 0:
        raise ValueError("input_text produced no tokens after cleaning/tokenization.")
    input_tensor = torch.tensor(input_seq, dtype=torch.long, device=device)

    # Inference only: keep every forward pass (embeddings included) out of autograd.
    with torch.no_grad():
        src_emb = model.pos_encoder(model.src_embedding(input_tensor))
        memory = model.transformer.encoder(src_emb)

        # Initial beam: just the <start> token with score 0.
        beams = [(torch.tensor([[start_token]], device=device), 0.0)]

        for _ in range(max_output_len):
            candidates = []
            for seq, score in beams:
                if seq[0, -1].item() == end_token:
                    candidates.append((seq, score))  # finished beam, keep as is
                    continue

                tgt_emb = model.pos_encoder(model.tgt_embedding(seq))
                tgt_mask = generate_square_subsequent_mask(seq.size(1)).to(device)

                output = model.transformer.decoder(tgt_emb, memory, tgt_mask=tgt_mask)
                logits = model.fc_out(output[:, -1, :])
                log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

                # torch.topk raises when k exceeds the dimension size, so clamp it.
                expand_k = min(beam_width, log_probs.size(-1))
                topk_log_probs, topk_indices = torch.topk(log_probs, expand_k)

                for log_prob, idx in zip(topk_log_probs[0].tolist(), topk_indices[0]):
                    next_seq = torch.cat([seq, idx.view(1, 1)], dim=1)
                    candidates.append((next_seq, score + log_prob))

            # Keep the beam_width best candidates (sorted() is stable on ties).
            beams = sorted(candidates, key=lambda cand: cand[1], reverse=True)[:beam_width]

            # Stop early once every surviving beam has produced <end>.
            if all(seq[0, -1].item() == end_token for seq, _ in beams):
                break

    # Best beam. Index the batch row explicitly instead of squeeze(): squeeze()
    # collapses a one-token sequence to a scalar, which cannot be iterated.
    best_seq = beams[0][0][0].tolist()

    # Drop <start> and <end> ids wherever they occur.
    decoded = [
        tokenizer.index_word.get(tok, '<UNK>')
        for tok in best_seq
        if tok not in (start_token, end_token)
    ]
    return ' '.join(decoded)


In [3]:
import torch
import torch.nn.functional as F
import random

def top_k_sampling_decode(model, tokenizer, input_text, k=10, max_output_len=100, device='cpu',
                          repetition_penalty=1.5, verbose=True):
    """Generate a reply for ``input_text`` by sampling from the top-k tokens.

    At every step the logits of all tokens already present in the decoder
    input (including ``<start>``) are lowered by ``repetition_penalty``; the
    next token is then drawn from the ``k`` most probable tokens, weighted by
    their probabilities.

    Args:
        model: Module exposing ``src_embedding``, ``tgt_embedding``,
            ``pos_encoder``, ``transformer.encoder``, ``transformer.decoder``
            and ``fc_out``. The decoder output is indexed as
            ``output[:, -1, :]``, i.e. a batch-first layout is assumed.
        tokenizer: Object exposing ``texts_to_sequences``, ``word_index`` and
            ``index_word``.
        input_text: Raw user text; it is passed through
            ``myTokenizer.clean_text`` before tokenization.
        k: Number of candidates to sample from; clamped to the vocabulary size.
        max_output_len: Maximum number of tokens to generate after ``<start>``.
        device: Device on which the input and decoder tensors are created.
        repetition_penalty: Amount subtracted from the logit of every token
            that was already generated.
        verbose: When True (default), print the top-k candidates and their
            probabilities at every step, as the function always used to do.

    Returns:
        The sampled tokens (without ``<start>`` / ``<end>``) joined by single
        spaces; ids missing from ``index_word`` become ``'<UNK>'``.

    Raises:
        ValueError: If ``k`` is below 1, if the tokenizer has no ``<start>``
            token, or if the input tokenizes to nothing. A missing ``<end>``
            token is tolerated: generation then runs for ``max_output_len``
            steps.
    """
    if k < 1:
        raise ValueError("k must be at least 1.")

    model.eval()

    start_token = tokenizer.word_index.get('<start>')
    end_token = tokenizer.word_index.get('<end>')
    if start_token is None:
        # Without this check torch.tensor([[None]]) fails with an obscure error.
        raise ValueError("Tokenizer must contain a <start> token.")

    # Step 1: clean and encode the input text.
    cleaned = myTokenizer.clean_text(input_text)
    input_seq = tokenizer.texts_to_sequences([cleaned])
    if len(input_seq) == 0 or len(input_seq[0]) == 0:
        raise ValueError("input_text produced no tokens after cleaning/tokenization.")
    input_tensor = torch.tensor(input_seq, dtype=torch.long, device=device)

    # Step 2: the decoder input starts with the <start> token.
    decoder_input = torch.tensor([[start_token]], device=device)

    # Inference only: keep every forward pass (embeddings included) out of autograd.
    with torch.no_grad():
        src_emb = model.pos_encoder(model.src_embedding(input_tensor))
        memory = model.transformer.encoder(src_emb)

        for _ in range(max_output_len):
            tgt_emb = model.pos_encoder(model.tgt_embedding(decoder_input))
            tgt_mask = generate_square_subsequent_mask(decoder_input.size(1)).to(device)

            output = model.transformer.decoder(tgt_emb, memory, tgt_mask=tgt_mask)
            logits = model.fc_out(output[:, -1, :])  # logits of the last position

            # Repetition penalty: lower the logit of each distinct token seen so far.
            used_tokens = torch.unique(decoder_input[0])
            logits[0, used_tokens] -= repetition_penalty

            probs = F.softmax(logits, dim=-1)

            # Step 3: top-k sampling. Index the batch row instead of squeeze()
            # so the results stay 1-D even when k == 1.
            top_k = min(k, probs.size(-1))
            topk_probs, topk_indices = torch.topk(probs[0], top_k)

            if verbose:
                print("Topk tokens:", [tokenizer.index_word.get(i) for i in topk_indices.tolist()])
                print("Topk probs:", topk_probs.tolist())

            # Draw one candidate; multinomial normalizes the weights itself.
            sampled_index = torch.multinomial(topk_probs, 1).item()
            next_token = topk_indices[sampled_index]

            decoder_input = torch.cat([decoder_input, next_token.view(1, 1)], dim=1)

            if next_token.item() == end_token:
                break

    # Map ids back to words, dropping the leading <start> and any <end>.
    output_tokens = decoder_input[0, 1:].tolist()
    decoded = [
        tokenizer.index_word.get(tok, '<UNK>')
        for tok in output_tokens if tok != end_token
    ]
    return ' '.join(decoded)


In [9]:
from myTokenizer import myTokenizer  
# Load the tokenizer and the model, then run one inference.
# NOTE(review): the import directly above rebinds the name `myTokenizer` from
# the module (imported in the first cell) to the class. The decode functions
# call `myTokenizer.clean_text(...)`, so their behavior may depend on which
# cell ran last -- confirm `clean_text` is reachable both ways.

TOKENIZER_PATH = '/tokenizer/tokenizerForHealthCare.pkl'  # change to the actual path
CHECKPOINT_PATH = 'checkpoint/weight_transformer_3550.pth'  # change to your saved model path
# Use the GPU when one is available instead of hard-coding 'cuda', so the cell
# also runs on CPU-only machines.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

ThisTokenizer = myTokenizer(num_words=10000)
tokenizer = ThisTokenizer.load_tokenizer(TOKENIZER_PATH)
# Vocabulary size handed to the model; presumably +1 reserves id 0 -- TODO confirm.
vocab_size = tokenizer.num_words + 1

model = TransformerChat(input_vocab_size=vocab_size, target_vocab_size=vocab_size)
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=DEVICE))
model = model.to(DEVICE)

# Run inference (alternative decoders are kept below for quick switching).
# response = greedy_decode(model, tokenizer, "I feel tired and dizzy", device=DEVICE)
response = beam_search_decode(model, tokenizer, "MY HEAD FEEL TERRIBALE", beam_width=5, device=DEVICE)
# response = top_k_sampling_decode(
#     model, tokenizer,
#     input_text="I feel tired and dizzy",
#     k=10,
#     device=DEVICE
# )
print("🤖 Bot:", response)


✅ Tokenizer is loaded successfully: /tokenizer/tokenizerForHealthCare.pkl
🤖 Bot: welcome welcome welcome welcome welcome welcome welcome welcome welcome welcome welcome welcome welcome welcome welcome welcome welcome welcome welcome welcome the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the


In [5]:
# Inspect the vocabulary of the loaded tokenizer.
word2idx = tokenizer.word_index

# Report how many entries word_index holds, next to num_words + 1.
print(f"Vocabulary size: {len(word2idx)}")
print(f"Number of words: {tokenizer.num_words + 1}")

# Reverse lookup table: token id -> word.
idx2word = dict(zip(word2idx.values(), word2idx.keys()))

# Print entries in word_index order, stopping after the first one whose id
# is above 15 (that entry is still printed).
for word, idx in word2idx.items():
    print(f"{word}: {idx}")
    if 15 < idx:
        break

# Look up the id of the <start> marker; a placeholder string is shown if absent.
token_id = word2idx.get("<start>", "<Not found>")
print("Token ID for '<start>':", token_id)

Vocabulary size: 41977
Number of words: 15001
<UNKNOWN>: 1
.: 2
,: 3
i: 4
and: 5
the: 6
to: 7
a: 8
is: 9
of: 10
you: 11
in: 12
for: 13
it: 14
your: 15
my: 16
Token ID for '<start>': 18
