In [1]:
import torch
from models.miniTransformer import generate_square_subsequent_mask
from models.miniTransformerV2 import TransformerChat
import myTokenizer



In [2]:
from myTokenizer import myTokenizer
ThisTokenizer = myTokenizer(num_words=15000)

# NOTE(review): the file is a .pkl, so load_tokenizer presumably unpickles it —
# only load tokenizer files you trust.
TOKENIZER = ThisTokenizer.load_tokenizer('/tokenizer/tokenizerForHealthCare.pkl')
print("tokenizer done, with length", len(TOKENIZER.word_index) + 1)
print("vocab size:", TOKENIZER.num_words)

# word -> index. Take a copy so the "<pad>" entry registered below does not
# mutate the tokenizer's own vocabulary (a hidden side effect on TOKENIZER).
word2idx = dict(TOKENIZER.word_index)

# index -> word
idx2word = {idx: word for word, idx in word2idx.items()}

# word_index does not reserve an entry for padding. The model uses index 0 as
# its padding index (pad_idx=0), so register "<pad>" explicitly in both maps.
word2idx["<pad>"] = 0
idx2word[0] = "<pad>"

# Source and target share a single vocabulary, so both names point at word2idx.
src_vocab = word2idx
trg_vocab = word2idx

✅ Tokenizer is loaded successfully: /tokenizer/tokenizerForHealthCare.pkl
tokenizer done, with length 41978
vocab size: 15000


In [3]:
# Ids of the special tokens that wrap the encoder input and start/stop decoding.
SOS_IDX = TOKENIZER.word_index.get("<start>", "<Not found>")
print("Token ID for '<start>':", SOS_IDX)

EOS_IDX = TOKENIZER.word_index.get("<end>", "<Not found>")
print("Token ID for '<end>':", EOS_IDX)

# Fail fast: the "<Not found>" placeholder is a string, and these ids are later
# used as defaults by health_chat, where a string would only surface as a
# confusing error when building torch tensors.
_missing_special_tokens = [
    tok for tok, idx in (("<start>", SOS_IDX), ("<end>", EOS_IDX)) if isinstance(idx, str)
]
if _missing_special_tokens:
    raise KeyError(f"special tokens missing from tokenizer vocabulary: {_missing_special_tokens}")

PAD_IDX = 0  # padding id; matches padding_idx=0 on the model's embedding layers

Token ID for '<start>': 18
Token ID for '<end>': 19


In [4]:
from models.BiGRU import EncoderBiGRU, DecoderGRU, Seq2SeqGRU, BahdanauAttention

# Hyper-parameters. These must match the training run, otherwise the
# checkpoint's tensor shapes will not line up with the freshly built model.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

vocab_size = TOKENIZER.num_words + 1  # num_words + 1 rows (15001 in the printed model)
output_dim = vocab_size               # generation task: output vocabulary == input vocabulary
embedding_dim = 256
hidden_dim = 512
pad_idx = 0

# Build the components: encoder, attention, then the decoder that consumes the attention.
encoder = EncoderBiGRU(vocab_size, embedding_dim, hidden_dim, pad_idx)
attention = BahdanauAttention(hidden_dim)
decoder = DecoderGRU(vocab_size, embedding_dim, hidden_dim, output_dim, pad_idx, attention)

# Wrap everything in the Seq2Seq model and move it to DEVICE.
model = Seq2SeqGRU(encoder, decoder, pad_idx, DEVICE).to(DEVICE)


In [5]:
# Restore the trained weights (tensors mapped onto DEVICE) and switch to eval mode.
# torch.load unpickles the checkpoint, so only point it at files you trust.
checkpoint_state = torch.load("checkpoint/weight_biGRU_1550.pth", map_location=DEVICE)
model.load_state_dict(checkpoint_state)
model.eval()


Seq2SeqGRU(
  (encoder): EncoderBiGRU(
    (embedding): Embedding(15001, 256, padding_idx=0)
    (gru): GRU(256, 512, batch_first=True, bidirectional=True)
  )
  (decoder): DecoderGRU(
    (embedding): Embedding(15001, 256, padding_idx=0)
    (gru): GRU(1280, 512, batch_first=True)
    (fc_out): Linear(in_features=512, out_features=15001, bias=True)
    (attention): BahdanauAttention(
      (attn): Linear(in_features=1536, out_features=512, bias=True)
      (v): Linear(in_features=512, out_features=1, bias=False)
    )
  )
  (bridge): Linear(in_features=1024, out_features=512, bias=True)
)

In [6]:
import torch
from typing import List

@torch.inference_mode()      # disables gradient tracking for the whole call
def health_chat(
    sentence: str,
    model: "Seq2SeqGRU",
    tokenizer_src,
    tokenizer_trg,
    device: torch.device,
    max_len: int = 100,
    sos_idx: int = SOS_IDX,
    eos_idx: int = EOS_IDX,
) -> str:
    """Generate a reply to ``sentence`` with a trained Seq2SeqGRU using greedy decoding.

    Args:
        sentence: Raw input text.
        model: Seq2SeqGRU with trained weights loaded; it is put into eval mode here.
        tokenizer_src: Tokenizer used to encode the input (``texts_to_sequences``).
        tokenizer_trg: Tokenizer used to decode output ids (``index_word``). Pass
            the same object as ``tokenizer_src`` when the vocabulary is shared.
        device: Device on which input tensors are created ("cuda"/"cpu" or torch.device).
        max_len: Maximum number of decoding steps.
        sos_idx: Id of the start-of-sequence token.
        eos_idx: Id of the end-of-sequence token.

    Returns:
        The generated words joined by single spaces, excluding the start/end
        markers. Ids without an entry in ``tokenizer_trg.index_word`` (such as
        the padding id 0) are skipped instead of raising KeyError.
    """
    model.eval()

    # 1. Preprocess: text -> ids, wrapped in start/end markers -> [1, src_len] tensor.
    src_ids = [sos_idx] + tokenizer_src.texts_to_sequences([sentence])[0] + [eos_idx]
    src_tensor = torch.tensor(src_ids, dtype=torch.long, device=device).unsqueeze(0)

    # 2. Encode, then map the encoder's final hidden state to the decoder's
    #    initial hidden state through the model's bridge layer.
    #    NOTE(review): assumed shapes (bridge is Linear(1024 -> 512) in the printout):
    #    encoder_outs [1, src_len, 2H], enc_hidden [1, 2H] — confirm in models.BiGRU.
    encoder_outs, enc_hidden = model.encoder(src_tensor)
    dec_hidden = torch.tanh(model.bridge(enc_hidden)).unsqueeze(0)  # add the GRU layer dim

    # 3. Greedy decoding: feed the previous argmax token back in at every step.
    generated: List[int] = []
    last_id = sos_idx
    for _ in range(max_len):
        last_token = torch.tensor([last_id], dtype=torch.long, device=device)  # [1]
        output, dec_hidden, _ = model.decoder(last_token, dec_hidden, encoder_outs)
        last_id = int(output.argmax(1))
        if last_id == eos_idx:
            break  # the end marker is never part of the reply
        generated.append(last_id)

    # 4. Ids -> words. Only the end marker is excluded (above), so the final real
    #    token is kept even when decoding stops at max_len without an <end>.
    index_word = tokenizer_trg.index_word
    reply_tokens = [index_word[i] for i in generated if i in index_word]
    return " ".join(reply_tokens).strip()


In [7]:
sentence = "i feel like i have persistent knee pain that has been going on for two days."

# health_chat tokenizes and tensorizes the sentence itself, so no manual
# word -> id conversion is needed here.
output = health_chat(sentence, model, TOKENIZER, TOKENIZER, DEVICE)
print("翻译结果：", output)


翻译结果： hi welcome to chat doctor . i read and understood your problem . pain lower back pain is suggestive of musculoskeletal pain or some skin , increasing pain increasing some neck pain increasing musculoskeletal pain increasing increasing some bones are some any further assessment .
