In [1]:
import json
import itertools
import random
import os

def _load_json(path):
    """Parse the UTF-8 JSON file at ``path`` and close the file afterwards."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


intent_ontology = _load_json("resources/intent_ontology.json")
slot_ontology = _load_json("resources/slot_ontology.json")
# An independent second copy of the slot ontology; _sample_string_value refills the
# (consumed) sample_values pools of slot_ontology from this one.
slot_ontology_backup = _load_json("resources/slot_ontology.json")
available_slots_combinations = _load_json("resources/available_slots_combinations.json")

In [2]:
def _get_slot_value_pairs(slot):
    """List every candidate ``(slot, value)`` pair for ``slot``.

    Candidates, in order:
      * ``"StringValue"`` when the slot's ontology ``type`` is ``"String"``;
        this placeholder is later swapped for a sampled value,
      * each entry of the slot's ``possible_values``,
      * ``"?"``, meaning the value is being asked for.

    Raises:
        KeyError: if ``slot`` (or its ``type`` / ``possible_values``) is missing
            from ``slot_ontology``.
    """
    slot_info = slot_ontology[slot]
    candidate_values = ["StringValue"] if slot_info["type"] == "String" else []
    candidate_values += slot_info["possible_values"] + ["?"]
    return [(slot, candidate) for candidate in candidate_values]

def _sample_string_value(slot, value):
    """Replace the ``"StringValue"`` placeholder with a sampled concrete value.

    Any other ``value`` is returned unchanged. For the placeholder, one entry is
    removed at random from ``slot_ontology[slot]["sample_values"]`` (sampling
    without replacement). When that pool is empty it is first refilled with a
    copy of ``slot_ontology_backup[slot]["sample_values"]``. If the pool is
    still empty after refilling, ``None`` is returned.

    Mutates ``slot_ontology`` and consumes the global ``random`` state.
    """
    if value != "StringValue":
        return value

    pool = slot_ontology[slot]["sample_values"]
    if not pool:
        # Pool exhausted: start over from a fresh copy of the backup list.
        pool = slot_ontology_backup[slot]["sample_values"].copy()
        slot_ontology[slot]["sample_values"] = pool

    if not pool:
        # No sample values at all (presumably slots like user_other).
        return None
    return pool.pop(random.randrange(len(pool)))

def _expand_optional_slots(optional_slots, required_slots):
    """Resolve optional-slot specs into concrete slot names without duplicates.

    An entry ending in ``_*`` is a wildcard: ``"foo_*"`` expands to every key of
    ``slot_ontology`` starting with ``"foo_"``. Other entries are kept verbatim.
    First-seen order is preserved. Repeated slots and slots that are already
    in ``required_slots`` are dropped, so no combination can contain a slot twice.
    """
    expanded = []
    for optional_slot in optional_slots:
        if optional_slot.endswith("_*"):
            # Keep the trailing "_" so "foo_*" matches "foo_x" but not "foobar".
            prefix = optional_slot[:-1]
            expanded.extend(slot for slot in slot_ontology if slot.startswith(prefix))
        else:
            expanded.append(optional_slot)
    required = set(required_slots)
    # dict.fromkeys removes duplicates while keeping first-seen order.
    return [slot for slot in dict.fromkeys(expanded) if slot not in required]


def get_slot_value_pairs_combinations(required_slots, optional_slots, min_optional_slots, max_optional_slots):
    """Enumerate every slot/value assignment for one intent.

    Each slot combination is ``required_slots`` plus ``k`` of the expanded optional
    slots, for every ``k`` in ``[min_optional_slots, max_optional_slots]``. Each
    combination is then expanded into the Cartesian product of the per-slot
    candidates returned by ``_get_slot_value_pairs``. Finally, ``"StringValue"``
    placeholders are replaced by ``_sample_string_value``, which consumes the
    ontology's sample pools and the global ``random`` state.

    Args:
        required_slots: slots present in every combination.
        optional_slots: optional slot names; entries ending in ``_*`` are prefix wildcards.
        min_optional_slots: minimum number of optional slots per combination.
        max_optional_slots: maximum number of optional slots per combination. It is
            capped at the number of available optional slots.

    Returns:
        list of lists of ``(slot, value)`` tuples.

    Raises:
        AssertionError: if ``min_optional_slots`` exceeds the (capped) maximum.
        KeyError: if a slot in a combination is unknown to ``slot_ontology``.
            The original KeyError is chained as the cause.
    """
    optional_candidates = _expand_optional_slots(optional_slots, required_slots)

    max_optional_slots = min(max_optional_slots, len(optional_candidates))
    assert min_optional_slots <= max_optional_slots, (
        f"min_optional_slots={min_optional_slots} exceeds max_optional_slots={max_optional_slots} "
        f"(capped at the number of distinct optional slots)"
    )

    # All slot combinations: required slots + k optional slots.
    slots_combinations = [
        list(required_slots) + list(optional_comb)
        for num_slots in range(min_optional_slots, max_optional_slots + 1)
        for optional_comb in itertools.combinations(optional_candidates, num_slots)
    ]

    # Expand every slot combination into all of its (slot, value) assignments.
    slot_value_pairs_combinations = []
    for slots_comb in slots_combinations:
        try:
            slot_value_pairs_list = [_get_slot_value_pairs(slot) for slot in slots_comb]
        except KeyError as err:
            raise KeyError(f"required_slots: {required_slots}\toptional_slots: {optional_slots}\tslots_comb: {slots_comb}") from err
        slot_value_pairs_combinations.extend(itertools.product(*slot_value_pairs_list))

    # Sample concrete string values only after the whole enumeration succeeded,
    # so a KeyError above leaves the sample pools untouched.
    return [
        [(slot, _sample_string_value(slot=slot, value=value)) for slot, value in slot_value_pairs]
        for slot_value_pairs in slot_value_pairs_combinations
    ]

def is_possible_da(speaker, intent, slot_value_pairs):
    """Decide whether a candidate dialogue act is admissible.

    Args:
        speaker: ``"system"`` or ``"user"``.
        intent: intent name, looked up in ``intent_ontology[speaker]``.
        slot_value_pairs: ``(slot, value)`` pairs; ``"?"`` marks a requested value.
            The pairs are iterated several times, so pass a list or tuple rather
            than a one-shot iterator.

    Returns:
        True when all of the following hold, otherwise False:
          * the speaker-specific rules below pass,
          * the set of slots is contained in some entry of ``available_slots_combinations``,
          * no value appears in the intent's ``impossible_values``.

    Raises:
        ValueError: if ``speaker`` is neither ``"system"`` nor ``"user"``.
    """
    if speaker == "system":
        if intent == "request":
            # The system supplies how_to_call_user and asks ("?") for every other slot.
            violation = any(
                (value == "?") if slot == "how_to_call_user" else (value != "?")
                for slot, value in slot_value_pairs
            )
        else:
            # Outside request: nothing is asked, and the system never says the user's name.
            violation = any(value == "?" or slot == "user_name" for slot, value in slot_value_pairs)
        if violation:
            return False

        # Intents needing at least one optional slot must not consist solely of
        # how_to_call_user (an empty slot list is rejected as well).
        if intent_ontology[speaker][intent]["min_optional_slots"]:
            if all(slot == "how_to_call_user" for slot, _ in slot_value_pairs):
                return False

    elif speaker == "user":
        # User requests ask for every slot; any other intent asks for none.
        asking = intent == "request"
        if any((value == "?") != asking for _, value in slot_value_pairs):
            return False

    else:
        raise ValueError(speaker)

    used_slots = {slot for slot, _ in slot_value_pairs}
    if not any(used_slots.issubset(allowed) for allowed in available_slots_combinations.values()):
        return False

    return not any(
        value in intent_ontology[speaker][intent]["impossible_values"]
        for _, value in slot_value_pairs
    )

def to_jpda(da):
    """Return a new dialogue-act dict whose intent and slot names use Japanese labels.

    ``id`` and ``speaker`` are copied unchanged. ``intent`` becomes
    ``intent_ontology[speaker][intent]["jp"]``, and each slot name becomes
    ``slot_ontology[slot]["jp"]``. Slot values are kept as-is, and ``da`` itself
    is not modified.
    """
    da_id, speaker, intent, slots = da["id"], da["speaker"], da["intent"], da["slots"]
    jp_intent = intent_ontology[speaker][intent]["jp"]
    jp_slots = {}
    for slot_name, slot_value in slots.items():
        jp_slots[slot_ontology[slot_name]["jp"]] = slot_value
    return {"id": da_id, "speaker": speaker, "intent": jp_intent, "slots": jp_slots}

def split_da_list(da_list, task_batch_size=20, group_batch_size=50, shuffle=True):
    """Chunk DAs into tasks, and tasks into groups.

    Args:
        da_list: list of dialogue acts. It is never modified.
        task_batch_size: maximum DAs per task; the last task may be shorter.
        group_batch_size: maximum tasks per group; the last group may be shorter.
        shuffle: if True, work on a random permutation of ``da_list``
            (``random.sample``), which consumes the global ``random`` state.

    Returns:
        list of groups, each a list of tasks, each a list of DAs. An empty
        ``da_list`` yields ``[[]]`` (a single empty group).
    """
    ordered = random.sample(da_list, k=len(da_list)) if shuffle else da_list
    groups = [[]]
    for start in range(0, len(ordered), task_batch_size):
        current_group = groups[-1]
        if len(current_group) >= group_batch_size:
            # Current group is full: open a new one.
            current_group = []
            groups.append(current_group)
        current_group.append(ordered[start:start + task_batch_size])
    return groups

In [3]:
speaker = "system"
max_optional_slots = 2
task_batch_size = 20  # DAs per task
group_batch_size = 50  # tasks per group
seed = 0  # fixes sampled string values and the task shuffle, so the output is reproducible


def _dump_json(obj, path, **kwargs):
    """Write ``obj`` to ``path`` as 4-space-indented UTF-8 JSON and close the file."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=4, **kwargs)


# _sample_string_value pops values out of slot_ontology's sample pools. Restore them
# from the backup so re-running this cell (without the loading cell) gives the same result.
for _slot, _slot_info in slot_ontology_backup.items():
    if "sample_values" in _slot_info:
        slot_ontology[_slot]["sample_values"] = list(_slot_info["sample_values"])
random.seed(seed)

count = 0
da_list = []
jpda_list = []
for intent, intent_info in intent_ontology[speaker].items():
    slot_value_pairs_combinations = get_slot_value_pairs_combinations(
        required_slots=intent_info["required_slots"],
        optional_slots=intent_info["optional_slots"],
        min_optional_slots=intent_info["min_optional_slots"],
        max_optional_slots=max_optional_slots,
    )
    for slot_value_pairs in slot_value_pairs_combinations:
        if not is_possible_da(speaker=speaker, intent=intent, slot_value_pairs=slot_value_pairs):
            continue
        da = {
            # e.g. "s00000000": speaker initial + zero-padded running index
            "id": f"{speaker[0]}{count:08d}",
            "speaker": speaker,
            "intent": intent,
            "slots": dict(slot_value_pairs),
        }
        count += 1
        da_list.append(da)
        jpda_list.append(to_jpda(da))

da_list_dpath = f"da_list-{speaker}-max{max_optional_slots}"
os.makedirs(da_list_dpath, exist_ok=True)
config = {
    "speaker": speaker,
    "max_optional_slots": max_optional_slots,
    "task_batch_size": task_batch_size,
    "group_batch_size": group_batch_size,
    "total_da": len(da_list),
    "seed": seed,
}
_dump_json(config, f"{da_list_dpath}/meta_info.json")
_dump_json(jpda_list, f"{da_list_dpath}/all.json", ensure_ascii=False)

# NOTE(review): group/task files left over from an earlier run that produced more
# groups or tasks are not removed here.
groups = split_da_list(da_list=jpda_list, task_batch_size=task_batch_size, group_batch_size=group_batch_size)
for g_i, group in enumerate(groups):
    group_dpath = f"{da_list_dpath}/group{g_i}"
    os.makedirs(group_dpath, exist_ok=True)
    for t_i, task in enumerate(group):
        _dump_json(task, f"{group_dpath}/task{t_i}.json", ensure_ascii=False)