In [18]:
import json
from transformers import pipeline
from collections import defaultdict
import re

In [19]:
data_path = 'MultiWOZ2.4-main/data/mwz24/MULTIWOZ2.4/data.json'
dialogue_acts_path = 'MultiWOZ2.4-main/data/mwz24/MULTIWOZ2.4/dialogue_acts.json'
test_list_path = 'MultiWOZ2.4-main/data/mwz24/MULTIWOZ2.4/testListFile.json'

In [20]:
with open(test_list_path) as f:
    test_list = [line.strip() for line in f]
with open(data_path) as f:
    data = json.load(f)
with open(dialogue_acts_path) as f:
    dialogue_acts = json.load(f)

In [21]:
test_data = [data[dialogue_id] for dialogue_id in test_list]

In [22]:
system_acts_classifier = pipeline('text-classification', model='model/system_acts_model', device=0, top_k=None)

In [23]:
def compare_lists(pred, true):
    # Tạo các tập hợp từ danh sách
    set1 = set(pred)
    set2 = set(true)

    # Phân loại các phần tử
    common_elements = set1 & set2  # Các phần tử chung (có trong cả hai list)
    only_in_pred = set1 - set2    # Phần tử chỉ có trong pred
    only_in_true = set2 - set1    # Phần tử chỉ có trong true

    # Kết quả được tô màu
    result = []

    # Xử lý phần tử trong cả hai danh sách (màu xanh)
    for item in common_elements:
        result.append(f"\033[94m{item} (matched)\033[0m")  # Màu xanh dương

    # Xử lý phần tử chỉ có trong danh sách 1 (màu đỏ)
    for item in only_in_pred:
        result.append(f"\033[91m{item} (only in pred)\033[0m")  # Màu đỏ

    # Xử lý phần tử chỉ có trong danh sách 2 (màu đỏ)
    for item in only_in_true:
        result.append(f"\033[91m{item} (only in true)\033[0m")  # Màu đỏ

    return "\n".join(result)

In [24]:
for i in range(100):  
    dialogue_id = test_list[i]
    print(f'Dialogue {dialogue_id}')
    max_turn = 3
    history = []
    sample_text = test_data[i]['log']
    for turn in range(0, len(sample_text) - 1, 2):
        history_text = '[SEP]'.join(history[max(0, turn - 2 * max_turn):turn])
        full_text = f'{history_text}[SEP]{sample_text[turn]}' if history_text else sample_text[turn]['text']
        acts_pred = [out['label'] for out in system_acts_classifier(full_text)[0] if out['score'] > 0.5]
        system_acts = dialogue_acts[dialogue_id[:-5]].get(str(turn//2 + 1), 'No Annotation')
        if system_acts == 'No Annotation':
            system_acts = ['No Annotation']
        else:
            system_acts = list(system_acts.keys())
        print(sample_text[turn]['text'])
        print(compare_lists(acts_pred, system_acts))
        print(sample_text[turn + 1]['text'])
    print('------------------------------------------------')

Dialogue MUL0484.json
I need train reservations from norwich to cambridge
[94mTrain-Request (matched)[0m
[91mTrain-Inform (only in true)[0m
I have 133 trains matching your request. Is there a specific day and time you would like to travel?
I'd like to leave on Monday and arrive by 18:00.
[91mTrain-Request (only in pred)[0m
[91mTrain-OfferBook (only in true)[0m
[91mTrain-Inform (only in true)[0m
There are 12 trains for the day and time you request.  Would you like to book it now?
Before booking, I would also like to know the travel time, price, and departure time please.
[94mTrain-Inform (matched)[0m
[91mTrain-Request (only in pred)[0m
[91mTrain-OfferBook (only in true)[0m
There are 12 trains meeting your needs with the first leaving at 05:16 and the last one leaving at 16:16. Do you want to book one of these? 
No hold off on booking for now.  Can you help me find an attraction called cineworld cinema?
[94mAttraction-Inform (matched)[0m
[94mgeneral-reqmore (matched)[

In [25]:
categorical_value_clf = pipeline('text-classification', model='model/categorical_value_model', device=0, top_k=None)
non_categorical_value_tclf = pipeline('ner', model='model/non_categorical_value_model', device=0)

def slot_vlaue_predict(full_text):
    state = defaultdict(list)
    categorical_value_result = categorical_value_clf(full_text)
    non_categorical_value_result = non_categorical_value_tclf(full_text)
    
    for out in categorical_value_result[0]:
        if out['score'] > 0.5:
            slot, value = out['label'].split('=')
            state[slot].append(value)
            
    current_entity = None
    current_value = ""
    
    for item in non_categorical_value_result:
        entity_type = item['entity'][2:]  # Remove the B- or I- prefix
        if item['entity'].startswith('B-'):
            if current_entity:  # Save the previous entity-value pair if exists
                if current_value.find(':') != -1:
                    current_value = current_value.replace(' ', '')
                state[current_entity].append(current_value)
            current_entity = entity_type
            current_value = item['word']
        elif item['entity'].startswith('I-') and current_entity == entity_type:
            if item['word'].startswith('##'):
                current_value += item['word'][2:]
            else:
                current_value += ' ' + item['word']  # Concatenate words for the same entity

    # Append the last entity-value pair
    if current_entity:
        if current_value.find(':') != -1:
            current_value = current_value.replace(' ', '')
        state[current_entity].append(current_value)
        
    return state

In [26]:
def get_slot_value(metadata):
    slot_values = {}
    for domain, domain_value in metadata.items():
        for slot, value in domain_value['book'].items():
            if slot == 'booked':
                continue
            if value and value != 'not mentioned':
                slot_values[f'{domain}-book {slot}'] = re.split(r'\||>', value)
        for slot, value in domain_value['semi'].items():
            if value and value != 'not mentioned':
                slot_values[f'{domain}-{slot}'] = re.split(r'\||>', value)
    return slot_values

In [27]:
def compare_dicts(pred_state, true_state):
    # Mã ANSI cho màu sắc
    GREEN = "\033[92m"  # Màu xanh lá
    RED = "\033[91m"    # Màu đỏ
    RESET = "\033[0m"   # Reset về màu mặc định
    
    result = []
    
    # Tập hợp tất cả các khóa có trong cả hai dict
    all_keys = set(pred_state.keys()).union(true_state.keys())
    
    for key in all_keys:
        if key in pred_state and key in true_state:
            if pred_state[key] == true_state[key]:
                # Nếu giống nhau, in màu xanh
                result.append(f"{GREEN}{key}: {pred_state[key]} (match){RESET}")
            else:
                # Nếu khác nhau, in màu đỏ
                result.append(f"{RED}{key}: {pred_state[key]} != {true_state[key]}{RESET}")
        elif key in pred_state:
            # Chỉ có trong pred_state
            result.append(f"{RED}{key}: {pred_state[key]} (only in pred_state){RESET}")
        elif key in true_state:
            # Chỉ có trong true_state
            result.append(f"{RED}{key}: {true_state[key]} (only in true_state){RESET}")
    
    # Gộp danh sách thành chuỗi, mỗi phần tử trên một dòng
    return "\n".join(result)

In [29]:
for i in range(100): 
    dialogue_id = test_list[i] 
    print(f'Dialogue {dialogue_id}')
    state = {}
    max_turn = 3
    history = []
    sample_text = test_data[i]['log']
    for turn in range(0, len(sample_text) - 1, 2):
        history_text = '[SEP]'.join(history[max(0, turn - 2 * max_turn):turn])
        full_text = f'{history_text}[SEP]{sample_text[turn]}' if history_text else sample_text[turn]['text']
        print(sample_text[turn]['text'])
        state = dict(state | slot_vlaue_predict(full_text))
        print(compare_dicts(state, get_slot_value(sample_text[turn + 1]['metadata'])))
        print(sample_text[turn + 1]['text'])
    print('------------------------------------------------')

Dialogue MUL0484.json
I need train reservations from norwich to cambridge
[92mtrain-destination: ['cambridge'] (match)[0m
[92mtrain-departure: ['norwich'] (match)[0m
I have 133 trains matching your request. Is there a specific day and time you would like to travel?
I'd like to leave on Monday and arrive by 18:00.
[92mtrain-day: ['monday'] (match)[0m
[92mtrain-destination: ['cambridge'] (match)[0m
[92mtrain-departure: ['norwich'] (match)[0m
[92mtrain-arriveBy: ['18:00'] (match)[0m
There are 12 trains for the day and time you request.  Would you like to book it now?
Before booking, I would also like to know the travel time, price, and departure time please.
[92mtrain-day: ['monday'] (match)[0m
[92mtrain-destination: ['cambridge'] (match)[0m
[91mtrain-departure: ['cambridge'] != ['norwich'][0m
[92mtrain-arriveBy: ['18:00'] (match)[0m
There are 12 trains meeting your needs with the first leaving at 05:16 and the last one leaving at 16:16. Do you want to book one of thes