# 대화시스템 실습 1 : TRADE DST 모듈



In [None]:
from tqdm import tqdm
import torch.nn as nn

from convlab2.dst.trade.multiwoz.utils.config import *
from src.TRADE import *
# from src import TRADE

import numpy as np
import shutil, zipfile
from convlab2.util.file_util import cached_path

In [None]:
from src.utils_multiWOZ_DST import *

def download_data(data_url="https://convlab.blob.core.windows.net/convlab-2/trade_multiwoz_data.zip"):
    """Download and unpack the TRADE MultiWOZ data archive into ``./data``.

    Nothing is done when both ``data/multi-woz`` and ``data/dev_dials.json``
    already exist under the current working directory. Otherwise the archive
    is fetched (unless ``data/trade_multiwoz_data.zip`` is already present)
    and extracted into ``data``.

    Args:
        data_url: URL of the zip archive to fetch.

    Raises:
        AssertionError: if the download did not produce a local file. The
            exception type is kept from the previous ``assert``-based check
            so existing callers see the same failure.
    """
    data_dir = os.path.join(os.path.abspath(os.path.curdir), 'data')
    already_unpacked = (
        os.path.exists(os.path.join(data_dir, 'multi-woz'))
        and os.path.exists(os.path.join(data_dir, 'dev_dials.json'))
    )
    if already_unpacked:
        return

    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(data_dir, exist_ok=True)

    zip_file_path = os.path.join(data_dir, 'trade_multiwoz_data.zip')
    if not os.path.exists(zip_file_path):
        print('downloading multiwoz TRADE data files...')
        # cached_path returns the local path of the cached download. Using it
        # directly avoids guessing the file from "<name>.json" entries in
        # data_dir, which picks the wrong file when other JSON files (e.g. a
        # partially extracted dev_dials.json) are already present.
        cached_file = cached_path(data_url, data_dir)
        # Explicit check instead of `assert`, which is stripped under -O.
        if not cached_file or not os.path.isfile(cached_file):
            print('allennlp download file error: TRADE Cross model download failed.')
            raise AssertionError(
                'downloaded archive not found for {}'.format(data_url))
        shutil.copyfile(cached_file, zip_file_path)

    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        print('unzipping data file ...')
        zip_ref.extractall(data_dir)


In [None]:
# Fetch and unpack the MultiWOZ TRADE data into ./data; returns immediately
# when the extracted data is already present.
download_data()

In [None]:
# Model Configuration
# Settings consumed by the data preparation, model construction and
# training cells of this notebook.

# --- task / data selection ---
task = "dst"
dataset = 'multiwoz'
decoder = 'TRADE'
load_embedding = 1

# --- model size and regularisation ---
hidden_size = 400
dropout_rate = 0.2

# --- optimisation ---
batch_size = 32
learning_rate = 0.001
gradient_clip = 10

# --- evaluation schedule and stopping ---
eval_period = 1   # evaluate every `eval_period` epochs
patience = 6      # evaluations without improvement tolerated before stopping
early_stop = path = None

In [None]:
# Progress trackers used by the training loop: best dev score so far,
# count of evaluations without improvement, and the latest dev score.
avg_best = 0.0
cnt = 0
acc = 0.0

# prepare_data_seq returns eight values; they are unpacked in order below.
(
    train, dev, test, test_special,
    lang, SLOTS_LIST, gating_dict, max_word,
) = prepare_data_seq(
    training=True,
    task=task,
    sequicity=False,
    batch_size=batch_size,
)

In [None]:
# Build the TRADE model from the configuration values and the objects
# returned by prepare_data_seq.
model = TRADE(
    # task and path settings
    task=task,
    path=path,
    # hyper-parameters
    hidden_size=hidden_size,
    dropout=dropout_rate,
    lr=learning_rate,
    # values obtained from data preparation
    lang=lang,
    slots=SLOTS_LIST,
    gating_dict=gating_dict,
    nb_train_vocab=max_word,
)

In [None]:
# Model Learning
# Train for at most 200 epochs. Every `eval_period` epochs the model is
# evaluated on the dev set; training stops after `patience` consecutive
# evaluations without improvement, or on a perfect score when `early_stop`
# is None.
for epoch in range(200):
    print("Epoch:{}".format(epoch))

    # One pass over the training batches.
    pbar = tqdm(enumerate(train), total=len(train))
    for i, data in pbar:
        # reset is True only for the first batch of each epoch.
        model.train_batch(data, int(gradient_clip), SLOTS_LIST[1], reset=(i == 0))
        model.optimize(gradient_clip)
        pbar.set_description(model.print_loss())

    # Skip evaluation on epochs that are not a multiple of eval_period.
    if (epoch + 1) % int(eval_period) != 0:
        continue

    acc = model.evaluate(dev, avg_best, SLOTS_LIST[2], early_stop)
    model.scheduler.step(acc)

    if acc >= avg_best:
        avg_best = acc
        cnt = 0
        # NOTE(review): this binds another name to the same live object, not
        # a snapshot; later epochs keep modifying what best_model refers to.
        best_model = model
    else:
        cnt += 1

    # Identity comparison with None (PEP 8) instead of `== None`.
    if cnt == patience or (acc == 1.0 and early_stop is None):
        print("Ran out of patient, early stop...")
        break

In [None]:
# Download Pre-trained Model and Demo
from convlab2.dst.trade.multiwoz.trade import *
# NOTE(review): MultiWOZTRADE() presumably fetches/loads the pre-trained
# weights when constructed — confirm against convlab2.
demo_model = MultiWOZTRADE()

In [None]:
# Single-turn demo: seed the dialogue history with one user utterance, run the
# tracker's update step on that utterance, and print the state it returns.
user_input = 'i need to book a hotel in the east that has 4 stars .'
# History is set to a list of [speaker, utterance] pairs; here only the one user turn.
demo_model.state['history'] = [['user', user_input]]
state = demo_model.update(user_input)
print(state)

In [None]:
# Top-level keys of the state returned by the tracker.
print("=== state.keys() ===")
print(list(state.keys()))

# Contents of the 'request_state' entry.
print("\n=== state['request_state'] ===")
print(state['request_state'])

# Contents of the 'belief_state' entry, one key per line with the key
# left-aligned in a 20-character column.
print("\n=== state['belief_state'] ===")
for section, content in state['belief_state'].items():
    print(f"{section!s:<20}", content)

In [None]:
# Model Evaluation
# Calls the demo tracker's own evaluate() method; the data and metrics it uses
# are defined inside convlab2's MultiWOZTRADE and are not visible here.
demo_model.evaluate()