## EmpDG 交互式对话示例

下面演示如何在 Notebook 中加载已训练好的 EmpDG 共情对话模型，并与之进行多轮实时对话。运行前请确认词向量文件 `./vectors/glove.6B.300d.txt` 与模型检查点 `result/<模型名>_best.tar` 均已就位。

In [1]:
# Imports and command-line options for the interactive demo.
import sys

# NOTE(review): sys.argv is replaced before `utils.config` is imported below,
# presumably because utils.config parses command-line options at import time —
# confirm in utils/config.py and keep this assignment above that import.
# NOTE(review): the Opts table printed by this cell shows save_path 'save/EmpDG/',
# not the '--save_path' value passed here — verify these options are actually read.
# NOTE(review): '--model' here and `model_name` in the loading cell are set
# independently (both 'EmpDG'); keep them in sync.
# Variant of the same option list with '--cuda' added (presumably enables GPU):
# sys.argv = ['', '--model', 'EmpDG', '--cuda', '--label_smoothing', '--noam', '--emb_dim', '300', '--rnn_hidden_dim', '300', '--hidden_dim', '300', '--hop', '1', '--heads', '2', '--pretrain_emb', '--device_id', '0', '--save_path', 'results/tb_results/EmpDG/', '--pointer_gen', '--emb_file', './vectors/glove.6B.300d.txt']
sys.argv = ['', '--model', 'EmpDG', '--label_smoothing', '--noam', '--emb_dim', '300', '--rnn_hidden_dim', '300', '--hidden_dim', '300', '--hop', '1', '--heads', '2', '--pretrain_emb', '--device_id', '0', '--save_path', 'results/tb_results/EmpDG/', '--pointer_gen', '--emb_file', './vectors/glove.6B.300d.txt']
import torch
import nltk  # NOTE(review): not referenced in the visible cells; may be needed indirectly — confirm
from collections import deque
from utils import config
from utils.data_loader import prepare_data_seq
from utils.data_reader import Lang
from Model.transformer import Transformer
from Model.EmoPrepend import EmoP
from Model.EmpDG_G import EmpDG_G
from interact import make_batch

                                      Opts                                      
--------------------------------------------------------------------------------
                                dataset: empathetic                             
                             hidden_dim: 300                                    
                                emb_dim: 300                                    
                             batch_size: 16                                     
                                 epochs: 10                                     
                                     lr: 0.0001                                 
                          max_grad_norm: 2.0                                    
                              beam_size: 5                                      
                              save_path: save/EmpDG/                            
                      save_path_dataset: save/                                  
                            

In [2]:
# Load the data and the trained model.
model_name = "EmpDG"   # one of SUPPORTED_MODELS below
SUPPORTED_MODELS = ("Transformer", "EmoPrepend", "EmpDG_woG", "EmpDG", "EmpDG_woD")
CKPT_PATH = f"result/{model_name}_best.tar"

# Fail fast on an unknown name, before the (slow) data preparation. Previously any
# unrecognised string silently fell through to EmpDG_G.
if model_name not in SUPPORTED_MODELS:
    raise ValueError(
        f"Unknown model_name {model_name!r}; expected one of {SUPPORTED_MODELS}"
    )

# Data iterators and vocabulary; the cells below use `vocab` and `program_number`.
data_loader_tr, data_loader_val, data_loader_tst, vocab, program_number = \
    prepare_data_seq(batch_size=config.batch_size)

# Build the model instance matching `model_name`.
if model_name == "Transformer":
    model = Transformer(vocab, decoder_number=program_number)
elif model_name in ("EmoPrepend", "EmpDG_woG"):
    model = EmoP(vocab, decoder_number=program_number)
else:  # "EmpDG" or "EmpDG_woD" (guaranteed by the check above)
    model = EmpDG_G(vocab, emotion_number=program_number)

# Load the checkpoint onto the CPU (same effect as `map_location=lambda s, t: s`);
# the model is moved to config.device afterwards.
# torch.load unpickles the file, so only load checkpoints from a trusted source.
# NOTE(review): the warning printed under this cell presumably concerns torch.load's
# `weights_only` default; passing it explicitly requires PyTorch >= 1.13 — confirm.
ckpt = torch.load(CKPT_PATH, map_location="cpu")
# EmpDG / EmpDG_woG checkpoints store the generator weights under "models_g";
# the other models use "models".
weights_key = "models_g" if model_name in ("EmpDG", "EmpDG_woG") else "models"
# Pass the state dict unchanged: rebuilding it with a dict comprehension dropped the
# `_metadata` attribute that load_state_dict uses for version-aware loading.
model.load_state_dict(ckpt[weights_key])
model.to(config.device)
model.eval()
print(f"{model_name} 模型加载完毕，开始交互。")

LOADING empathetic_dialogue ...


08-11 15:49 Vocab  22359 


[situation]: i spent all weekend working on my truck to fix a miss in the engine . despite spending over $ 200 in parts , it did not do a thing to fix the miss .
[emotion]: angry
[context]: ['i worked on my truck all weekend . spent $ 215 on parts and a special tool . still did not fix the miss in the engine .']
[emotion context]: truck spent parts special fix miss engine
[target]: what is the problem with the engine ?
[feedback]: it has a miss . sometimes when you are driving it sputters and jerks and barely has any power . other times it drives just fine . i have replaced so many things on it and nothing makes a difference for very long .
 
[situation]: a few years ago , my marriage broke up , and i found myself living alone for the first time in my life . though i eventually grew accustomed to the solitude , it took a while to get used to it .
[emotion]: lonely
[context]: ['i found myself divorced a few years ago , and for the first time in my life , i was living alone .']
[emotion 

  ckpt = torch.load(f"result/{model_name}_best.tar", map_location=lambda s, t: s)


In [3]:
# Define the chat helper and run the interactive loop.
DIALOG_SIZE = 5       # number of most recent turns (user + model) kept as context
MAX_DEC_STEP = 100    # maximum decoding steps per model reply

# Rolling dialogue history, pre-filled with "None" placeholders so it always holds
# DIALOG_SIZE entries; appending beyond maxlen discards the oldest entry.
context = deque(["None"] * DIALOG_SIZE, maxlen=DIALOG_SIZE)


def chat_once(usr_utt: str) -> str:
    """Send one user utterance to the model and return its reply.

    Both the utterance and the reply are appended to the module-level `context`
    deque, so consecutive calls form a multi-turn conversation.

    Args:
        usr_utt: The user's message.

    Returns:
        The first element of `model.decoder_greedy(...)` for the current context.
    """
    context.append(usr_utt)
    batch = make_batch(context, vocab)
    # Inference only: disable autograd so no computation graph is kept per turn.
    with torch.no_grad():
        reply = model.decoder_greedy(batch, max_dec_step=MAX_DEC_STEP)[0]
    context.append(reply)
    return reply


# Interactive loop: 'q' (case-insensitive) quits; blank lines are ignored.
print("输入'q'退出对话。")
while True:
    try:
        usr = input(">> User: ").strip()
    except (EOFError, KeyboardInterrupt):
        # stdin closed or kernel interrupted while waiting: exit like 'q'.
        usr = "q"
    if usr.lower() == 'q':
        print("退出对话。")
        break
    if not usr:
        # Do not feed an empty utterance into the dialogue context.
        continue
    resp = chat_once(usr)
    print(f"{model_name}: {resp}")

输入'q'退出对话。
EmpDG: hi 
EmpDG: oh wow , i am so happy for you ! 
EmpDG: i am so sorry . i am so sorry . i am not a fan of a few years ago . 
退出对话。
