In [1]:
import torch
from models.miniTransformer import generate_square_subsequent_mask
from models.miniTransformerV2 import TransformerChat
import myTokenizer



In [2]:
import torch.nn.functional as F

def decode_with_topk_penalty(model, tokenizer, input_text, k=10, max_len=100, temperature=1.2, penalty=1.5, device='cpu'):
    """Generate a reply for ``input_text`` with top-k sampling and a repetition penalty.

    The input text is cleaned, tokenized and run through the model's encoder
    once. The decoder is then run step by step starting from ``<start>``; at
    each step the logits of the last position are divided by ``temperature``,
    ``penalty`` is subtracted from the logit of every token id already present
    in the decoder input (this includes ``<start>``), and the next token is
    sampled from the ``k`` most probable tokens.

    Args:
        model: Object exposing ``embedding``, ``pos_encoder``, ``norm``,
            ``transformer.encoder``, ``transformer.decoder`` and ``fc_out``.
            It is switched to eval mode and left in that mode.
        tokenizer: Object exposing ``word_index`` (word -> id, must contain
            ``'<start>'`` and ``'<end>'``), ``index_word`` (id -> word) and
            ``texts_to_sequences``.
        input_text: Raw user text.
        k: Number of highest-probability candidates to sample from. Values
            larger than the vocabulary size are clamped to the vocabulary size.
        max_len: Maximum number of tokens to generate.
        temperature: Divisor applied to the logits before the softmax.
        penalty: Amount subtracted from the logit of already-generated tokens.
        device: Device the tensors are created on; the model is expected to
            already live on that device.

    Returns:
        The generated words joined by single spaces. ``<end>`` is omitted and
        ids missing from ``tokenizer.index_word`` are rendered as ``'<UNK>'``.
        An empty string is returned when nothing is generated.

    Raises:
        KeyError: If ``'<start>'`` or ``'<end>'`` is missing from the vocabulary.
        ValueError: If ``input_text`` yields no tokens after cleaning.
    """
    model.eval()
    start_id = tokenizer.word_index['<start>']
    end_id = tokenizer.word_index['<end>']

    cleaned = myTokenizer.clean_text(input_text)
    input_ids = tokenizer.texts_to_sequences([cleaned])
    if not input_ids or len(input_ids[0]) == 0:
        # An empty id list would become a float tensor of shape (1, 0) and
        # fail inside the embedding lookup with an unrelated-looking error.
        raise ValueError("input_text produced no tokens after cleaning/tokenization")
    input_tensor = torch.tensor(input_ids, dtype=torch.long, device=device)

    # The whole routine is inference only, so no autograd graph is needed
    # for the encoder pass or for any decoding step.
    with torch.no_grad():
        src_emb = model.pos_encoder(model.embedding(input_tensor))
        memory = model.transformer.encoder(model.norm(src_emb))

        decoder_input = torch.tensor([[start_id]], device=device)

        for _ in range(max_len):
            tgt_emb = model.pos_encoder(model.embedding(decoder_input))
            tgt_mask = generate_square_subsequent_mask(decoder_input.size(1)).to(device)

            output = model.transformer.decoder(tgt_emb, memory, tgt_mask=tgt_mask)
            logits = model.fc_out(model.norm(output[:, -1, :])) / temperature

            # Repetition penalty: each distinct token seen so far is penalised once.
            seen_ids = torch.unique(decoder_input[0])
            logits[0, seen_ids] -= penalty

            probs = F.softmax(logits, dim=-1)
            # topk raises if asked for more entries than the vocabulary holds.
            k_eff = min(k, probs.size(-1))
            topk_probs, topk_ids = torch.topk(probs, k_eff)
            # multinomial accepts unnormalised weights, so the top-k slice
            # does not need to be renormalised first.
            choice = torch.multinomial(topk_probs[0], 1)
            next_token = topk_ids[0, choice]

            decoder_input = torch.cat([decoder_input, next_token.view(1, 1)], dim=1)
            if next_token.item() == end_id:
                break

    # Explicit indexing (rather than squeeze()) also works when nothing was
    # generated and decoder_input still has shape (1, 1).
    output_tokens = decoder_input[0, 1:].tolist()
    return ' '.join(tokenizer.index_word.get(i, '<UNK>') for i in output_tokens if i != end_id)


In [3]:
from myTokenizer import myTokenizer
# NOTE(review): the line above rebinds the name `myTokenizer`, which an earlier
# cell bound with `import myTokenizer`. `decode_with_topk_penalty` calls
# `myTokenizer.clean_text(...)` through whichever binding is current, so it
# presumably relies on `clean_text` being reachable from this object too — confirm.

# Use the GPU when one is available and fall back to the CPU otherwise, so the
# cell also runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the tokenizer and the model.
ThisTokenizer = myTokenizer(num_words=10000)
tokenizer = ThisTokenizer.load_tokenizer('/tokenizer/tokenizerForMentalHealth.pkl')  # change to the actual tokenizer path
vocab_size = tokenizer.num_words + 1

model = TransformerChat(vocab_size=vocab_size)
# NOTE(review): torch.load unpickles the checkpoint file; only load checkpoints
# from a trusted source.
model.load_state_dict(torch.load('checkpoint/weight_transformerV2_JSON_1550.pth', map_location=device))  # change to the path of your saved weights
model = model.to(device)

response = decode_with_topk_penalty(model, tokenizer, "I feel tired and dizzy", device=device)
print("🤖 Bot:", response)

test2 = """My doctor had issue with finding the baby during my transvaginal ultrasound is that because I had to pee she could see the yoke sac and I measured 9 weeks but what she thought was the baby was measuring 7 weeks. I am just worried and want to know if that interfered"""

response = decode_with_topk_penalty(model, tokenizer, test2, device=device)
print("🤖 Bot:", response)

✅ Tokenizer is loaded successfully: /tokenizer/tokenizerForMentalHealth.pkl
🤖 Bot: i am down up back from other necessary down <UNKNOWN> terrible nightmare or ? as that i reading anything anything vet back back back her her her up back mindfulness meditation diagnosis differently until next next because because because changed violent back community seriously is an better down them or down terrible them . next next next necessary others fully fully down down dangerous them were next past worse were again wanted into low next next she she felt she assumed he she or down them or or anyone someone they were so it because the next ? next next
🤖 Bot: thanks up those those too those coping tools <UNKNOWN> present or those patterns <UNKNOWN> , steps eight ansiedad story brief insurance nightmare story ? other dreams nightmare whatever action employer de y down breathing j self-esteem pets from deep boundary them him him abusive past past their present them those , . . present those them the p

In [4]:
# Shorthand for the tokenizer's word -> id mapping.
word2idx = tokenizer.word_index

# Report how many entries the mapping holds, next to num_words + 1.
print(f"Vocabulary size: {len(word2idx)}")
print(f"Number of words: {tokenizer.num_words + 1}")

# Inverse mapping (id -> word), built by pairing each id with its word.
idx2word = dict(zip(word2idx.values(), word2idx.keys()))

# Print entries in the mapping's iteration order, stopping right after the
# first entry whose id is above 15 has been printed.
for word, idx in word2idx.items():
    print(f"{word}: {idx}")
    if idx <= 15:
        continue
    break

# Look up the id of the start-of-sequence marker, with a placeholder if absent.
token_id = word2idx.get("<start>", "<Not found>")
print("Token ID for '<start>':", token_id)

Vocabulary size: 14691
Number of words: 13001
<UNKNOWN>: 1
.: 2
to: 3
,: 4
you: 5
and: 6
i: 7
the: 8
a: 9
is: 10
that: 11
of: 12
your: 13
in: 14
it: 15
are: 16
Token ID for '<start>': 22
