In [1]:
import openpyxl
import json
from collections import Counter
import pandas as pd
import re
from collections import Counter, defaultdict
import random
from copy import deepcopy
import pdb

In [13]:
def load_xlsx_data(filename, sheetname="Sheet1", header=True, read_only=False):
    """Load the rows of one worksheet from an .xlsx workbook.

    Args:
        filename: path to the workbook.
        sheetname: name of the worksheet to read.
        header: when True, the first row supplies the column names and every
            following row is returned as a ``{column_name: value}`` dict; when
            False, every row (including the first) is returned as a list of
            cell values.
        read_only: open the workbook in openpyxl's read-only (streaming) mode,
            which is much faster and lighter on memory for large exports.

    Returns:
        A list of row dicts (``header=True``) or of row value lists
        (``header=False``). The workbook is closed before returning.
    """
    workbook = openpyxl.load_workbook(filename, read_only=read_only)
    try:
        rows = workbook[sheetname].iter_rows(values_only=True)
        headers = []
        contents = []
        for idx, row in enumerate(rows):
            row_values = list(row)
            if idx == 0 and header:
                headers = row_values
                continue
            if header:
                # Streaming mode may yield rows shorter than the header row;
                # pad them so every dict carries every column key.
                row_values += [None] * (len(headers) - len(row_values))
                row_values = dict(zip(headers, row_values))
            contents.append(row_values)
        return contents
    finally:
        # Required to release the file handle in read-only mode; harmless otherwise.
        workbook.close()


# Fallback extractors, used by find_dr / find_sn / find_query below when a
# dialog_root response body cannot be handled by json.loads.
# RGX_INTENTS captures the text after `intents":` up to and including the first
# `]` that is immediately followed by `}]`; find_dr json-decodes that capture.
RGX_INTENTS = re.compile('intents":(.*?])}]')
# RGX_SN captures the first double-quoted value that follows `"sn":`.
RGX_SN = re.compile('"sn":.*?"(.*?)"')
# RGX_QUERY captures the first double-quoted value that follows `"raw_query":`.
RGX_QUERY = re.compile('"raw_query":.*?"(.*?)"')


def find_dr(target, rgx=RGX_INTENTS):
    """Regex-scrape and JSON-decode the intents list from a raw dialog_root body.

    Fallback for response bodies that are not valid JSON as a whole.

    Args:
        target: raw response-body text. Non-string input (e.g. None from an
            empty spreadsheet cell) yields [].
        rgx: compiled pattern whose first capture holds the intents JSON text.

    Returns:
        The decoded JSON value of the first match, or [] when nothing matches
        or the captured text is not valid JSON.
    """
    if not isinstance(target, str):
        return []
    matches = rgx.findall(target)
    if not matches:
        return []
    try:
        return json.loads(matches[0])
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return []
    
    
def find_sn(target, rgx=RGX_SN):
    """Regex-scrape the first "sn" value from a raw (possibly malformed) body.

    Returns "" when ``target`` is not a string (e.g. None from an empty
    spreadsheet cell) or when nothing matches.
    """
    if not isinstance(target, str):
        return ""
    matches = rgx.findall(target)
    return matches[0] if matches else ""


def find_query(target, rgx=RGX_QUERY):
    """Regex-scrape the first "raw_query" value from a raw (possibly malformed) body.

    Returns "" when ``target`` is not a string (e.g. None from an empty
    spreadsheet cell) or when nothing matches.
    """
    if not isinstance(target, str):
        return ""
    matches = rgx.findall(target)
    return matches[0] if matches else ""


def _json_or_default(text, default):
    """Decode ``text`` as JSON; return ``default`` if it is missing, not text,
    malformed, or decodes to ``null``."""
    try:
        value = json.loads(text)
    except (TypeError, ValueError):
        return default
    return default if value is None else value


def convert_data(rec):
    """Convert one raw spreadsheet row into a normalized NLU record.

    Args:
        rec: row dict as returned by load_xlsx_data. Columns read:
            "rank_order", "nlu_model_nlp", "grammar_process_responsebody",
            "result", "dialog_root_responsebody", "should_use_this",
            "p_traceid".

    Returns:
        dict with keys "sn", "model_topn", "query", "entities",
        "grammar_topn", "rank_order", "use_nlu_model", "grammar_top1",
        "traceid", "model_top1". JSON-valued columns fall back to empty values
        when missing, undecodable or null. The dialog_root body falls back to
        regex scraping (find_sn / find_dr / find_query) when it is not a JSON
        object of the expected shape.
    """
    # "rank_order" holds newline-separated "domain|intent" lines; malformed lines are dropped.
    rank_pairs = [x.split("|") for x in rec["rank_order"].split("\n") if x] if rec["rank_order"] else []
    rank_order = [{"domain": x[0], "intent": x[1]} for x in rank_pairs if len(x) == 2]

    # A null payload becomes the empty default, so downstream code can always iterate it.
    nlu_model = _json_or_default(rec.get("nlu_model_nlp"), {})
    model_top1 = _json_or_default(rec.get("result"), [])

    grammar_body = _json_or_default(rec.get("grammar_process_responsebody"), {})
    if not isinstance(grammar_body, dict):
        grammar_body = {}
    # `or []` also covers an explicit null "intents" value.
    gram_intents = grammar_body.get("intents") or []

    body = rec.get("dialog_root_responsebody")
    body = "" if body is None else str(body)
    try:
        dialog_root = json.loads(body)
        sn = dialog_root["sn"]
        nlu_res = (dialog_root.get("nlu_res") or [{}])[0]
        sorted_intents = nlu_res.get("intents") or []
        raw_query = dialog_root.get("raw_query")
        # A missing "entities" key must not discard the rest of the parsed body.
        dr_entities = nlu_res.get("entities") or []
    except (TypeError, ValueError, KeyError, IndexError, AttributeError):
        # Not a JSON object of the expected shape (e.g. truncated): scrape what we can.
        sn = find_sn(body)
        sorted_intents = find_dr(body)
        raw_query = find_query(body)
        dr_entities = []

    query = grammar_body.get("query") or raw_query
    entities = grammar_body.get("entities") or dr_entities
    res = {
        "sn": sn,
        "model_topn": nlu_model,
        "query": query,
        "entities": entities,
        "grammar_topn": gram_intents or sorted_intents,
        "rank_order": rank_order,
        "use_nlu_model": rec["should_use_this"],
        "grammar_top1": sorted_intents and [sorted_intents[0]],
        "traceid": rec["p_traceid"],
        "model_top1": model_top1
    }
    return res

In [48]:
# Full MINI-NLU export (6.15-6.21). Loading it is slow: the recorded run of
# this cell was interrupted with a KeyboardInterrupt.
raw_data = load_xlsx_data(filename="./MINI-NLU-全量数据(6.15-6.21).xlsx", sheetname="Sheet")

KeyboardInterrupt: 

In [53]:
# Known domains (used for the "2nd" bucket): every non-empty MINI_DOMAIN cell of the scope sheet.
domain_scope = load_xlsx_data("./domain_scope.xlsx", sheetname="domain.intent")
DOMAINS = {row["MINI_DOMAIN"] for row in domain_scope if row["MINI_DOMAIN"]}
DOMAINS

{'alarm_1',
 'alarm_3',
 'fm_1',
 'fm_2',
 'global_command',
 'music_1',
 'music_2',
 'other',
 'weather_1'}

In [15]:
# Normalize every raw spreadsheet row into the record dicts used below.
data = list(map(convert_data, raw_data))

In [87]:
def clean_str(target):
    """Normalize a query for de-duplication and emptiness checks.

    Lower-cases the text and removes every leading and trailing character
    that is not a word character (punctuation and whitespace alike), e.g.
    " 今天天气好不好？ " -> "今天天气好不好" and "！！！" -> "".
    Interior characters are kept untouched.

    Returns "" for falsy input (None, ""); other non-string input is
    converted with str() first.
    """
    if not target:
        return ""
    res = str(target).lower()
    # Strip whole runs, not just one character, so "？？！" or "? hi" normalize fully.
    return re.sub(r'^\W+|\W+$', '', res)

print(clean_str(" 今天天气好不好？ "))


AWAKENS = ["小豹", "小雅"]

def all_awakens(target, awakens=AWAKENS):
    """Return True if ``target`` is made up solely of wake words.

    ``target`` matches when it is a concatenation of zero or more entries of
    ``awakens`` (e.g. "小豹小雅小豹"). Wake words may have any length and may
    be mixed freely. Empty entries in ``awakens`` are ignored. An empty
    ``target`` returns True (zero wake words).
    """
    words = [re.escape(word) for word in awakens if word]
    if not words:
        # Only the empty string is a concatenation of zero wake words.
        return not target
    return re.fullmatch("(?:%s)*" % "|".join(words), target) is not None



def _coarse_domain(domain):
    """Map a "music_*"/"fm_*"/"video_*" domain to "music"/"fm"/"video"; else None."""
    for prefix in ("music", "fm", "video"):
        if domain.startswith(prefix + "_"):
            return prefix
    return None


def get_top1_domain(rec, domains=None):
    """Classify a record into a coarse domain bucket for stratified sampling.

    Priority (first rule that applies wins):
      1. "music" / "fm" / "video": the first grammar intent whose domain starts
         with "music_", "fm_" or "video_"; failing that, the first such model
         intent.
      2. "2nd": some grammar or model intent has a domain in ``domains`` that
         is not "other".
      3. "3rd": some intent has domain "other" or a domain outside ``domains``.
    Returns None when the record has neither grammar nor model intents.

    Args:
        rec: dict with optional "grammar_topn" and "model_topn" lists of intent
            dicts. A grammar intent's domain is the first non-empty of
            "grammar_pkg_name", "domain_name", "domain"; a model intent's
            domain is its "domain" key. Null lists are treated as empty.
        domains: collection of known domains; defaults to the global DOMAINS.

    Raises:
        ValueError: a grammar intent has no non-empty domain field.
        KeyError: a model intent has no "domain" key.
    """
    if domains is None:
        domains = DOMAINS
    has_2nd = False
    has_3rd = False
    for intent in rec.get("grammar_topn") or []:
        domain = intent.get("grammar_pkg_name") or intent.get("domain_name") or intent.get("domain")
        if not domain:
            raise ValueError("grammar intent without a domain: %r" % (intent,))
        first = _coarse_domain(domain)
        if first:
            return first
        if domain == "other" or domain not in domains:
            has_3rd = True
        else:
            has_2nd = True
    for intent in rec.get("model_topn") or []:
        domain = intent["domain"]
        first = _coarse_domain(domain)
        if first:
            return first
        if domain == "other" or domain not in domains:
            has_3rd = True
        else:
            has_2nd = True
    if has_2nd:
        return "2nd"
    if has_3rd:
        return "3rd"
    return None


def _bucket_records(data):
    """Annotate each record with ``norm_query`` and group record indexes by bucket.

    Returns (t2idx, qcounts, valid_count):
      * t2idx: defaultdict mapping bucket name ("empty", "awakens", or the
        label from get_top1_domain) to a list of record indexes;
      * qcounts: Counter of normalized queries over non-empty,
        non-wake-word records;
      * valid_count: number of records placed in a bucket other than
        "3rd", "empty" or "awakens".
    """
    t2idx = defaultdict(list)
    qcounts = Counter()
    valid_count = 0
    for idx, rec in enumerate(data):
        norm_query = clean_str(rec.get("query", ""))
        rec["norm_query"] = norm_query
        if not norm_query:
            t2idx["empty"].append(idx)
            continue
        if all_awakens(norm_query):
            t2idx["awakens"].append(idx)
            continue
        qcounts[norm_query] += 1
        domain_x = get_top1_domain(rec)
        if domain_x is None:
            # No grammar or model intents at all: never sampled.
            continue
        t2idx[domain_x].append(idx)
        if domain_x != "3rd":
            valid_count += 1
    return t2idx, qcounts, valid_count


def sample_records(data, total=2000, seed=None):
    """Draw a domain-stratified sample of records with unique normalized queries.

    Side effect: sets ``rec["norm_query"]`` on every record of ``data``.

    Records with an empty normalized query or made only of wake words are
    never sampled. The "3rd" bucket gets ``int(total * 0.2)`` slots. The other
    buckets share the remaining 80% in proportion to their size. The last
    bucket is topped up so the sample reaches ``total`` (shortfalls of earlier
    buckets included), as far as unique queries allow. One line per bucket
    is printed: name, quota, records drawn, running total.

    Args:
        data: list of record dicts (as produced by convert_data).
        total: target sample size.
        seed: optional seed for a private ``random.Random``. When None, the
            module-level ``random`` state is used.

    Returns:
        (samples, t2idx, qcounts): deep copies of the sampled records tagged
        with "dtype" (bucket), "dcount" (bucket size) and "qcount" (frequency of
        the normalized query); the bucket -> indexes mapping; and the Counter of
        normalized-query frequencies.
    """
    rng = random if seed is None else random.Random(seed)
    t2idx, qcounts, valid_count = _bucket_records(data)

    # "3rd" is drawn first; the other buckets follow in first-seen order.
    pseudo_domains = ["3rd"] + [
        name for name in t2idx if name not in ("3rd", "empty", "awakens")]
    last_pos = len(pseudo_domains) - 1

    samples = []
    seen_queries = set()
    aggregates = 0  # records actually drawn so far (not planned quotas)
    for pos, domain_x in enumerate(pseudo_domains):
        raw_indexes = t2idx[domain_x]
        idxlen = len(raw_indexes)
        remaining = total - aggregates
        if domain_x == "3rd":
            num = int(total * 0.2)
        else:
            # valid_count > 0 here: a non-"3rd" bucket exists only if it got records.
            num = min(remaining, int(total * .8 * idxlen / valid_count))
        if pos == last_pos:
            # The last bucket absorbs rounding losses and earlier shortfalls.
            num = max(num, remaining)
        if not num:
            continue
        # Uniform random order without replacement, in O(n) rather than repeated list.remove.
        candidates = [i for i in raw_indexes if data[i]["norm_query"] not in seen_queries]
        rng.shuffle(candidates)
        count = 0
        for rec_idx in candidates:
            if count >= num:
                break
            query = data[rec_idx]["norm_query"]
            if query in seen_queries:
                # Duplicate normalized query already drawn from this bucket.
                continue
            seen_queries.add(query)
            rec = deepcopy(data[rec_idx])
            rec["dtype"] = domain_x
            rec["dcount"] = idxlen
            rec["qcount"] = qcounts[query]
            samples.append(rec)
            count += 1
        aggregates += count
        print(domain_x, num, count, aggregates)
    return samples, t2idx, qcounts

今天天气好不好


In [88]:
# Draws are random and no seed is set above, so re-running changes the sample.
samples, t2idx, qcounts = sample_records(data, total=2000)

3rd 400 400 400
music 842 842 1242
2nd 387 387 1629
fm 188 188 1817
video 183 183 2000


In [89]:
# Sanity check: distinct raw queries in the sample. Normalized queries are
# de-duplicated during sampling, so this should equal len(samples).
len({rec["query"] for rec in samples})

2000

In [90]:

def get_domains(target):
    """Yield one coarse domain label per intent in ``target``.

    An intent's domain is the first non-empty value among "grammar_pkg_name",
    "domain_name" and "domain" (the same lookup get_top1_domain uses for
    grammar intents). Domains starting with "music_", "fm_" or "video_"
    collapse to "music", "fm" or "video"; any other domain is yielded as is.

    Raises:
        ValueError: an intent has no non-empty domain field.
    """
    for intent in target:
        # `or` (not nested .get defaults) so a present-but-null/empty key falls through.
        domain = intent.get("grammar_pkg_name") or intent.get("domain_name") or intent.get("domain")
        if not domain:
            raise ValueError("intent without a domain: %r" % (intent,))
        for prefix in ("music", "fm", "video"):
            if domain.startswith(prefix + "_"):
                yield prefix
                break
        else:
            yield domain

def describe_data(data, target="grammar_topn"):
    """Count records per coarse domain of the intents stored under ``target``.

    Side effect: sets ``rec["norm_query"]`` (via clean_str) on every record.

    The returned Counter holds:
      * "total": all records;
      * "empty": records whose normalized query is blank;
      * "awakens": records whose query is only wake words;
      * "valid": all remaining records;
      * one entry per domain label from get_domains. Each valid record counts
        a label at most once, and records without intents count as "other".
    NOTE(review): a domain literally named "total"/"valid"/"empty"/"awakens"
    would share a key with the bookkeeping counts.
    The Counter is also printed, prefixed by ``target``.
    """
    qcounts = Counter()
    for rec in data:
        qcounts["total"] += 1
        norm_query = clean_str(rec.get("query", ""))
        rec["norm_query"] = norm_query
        if not norm_query:
            qcounts["empty"] += 1
            continue
        if all_awakens(norm_query):
            qcounts["awakens"] += 1
            continue
        qcounts["valid"] += 1
        # `or []`: the field may be present but null.
        domains = set(get_domains(rec.get(target) or []))
        qcounts.update(domains or {"other"})
    print(target, qcounts)
    return qcounts


# Domain distributions of the full data set and of the sample, for both intent sources.
raw_grammar_topn_counts = describe_data(data, target="grammar_topn")
raw_model_topn_counts = describe_data(data, target="model_topn")
sample_grammar_topn_counts = describe_data(samples, target="grammar_topn")
sample_model_topn_counts = describe_data(samples, target="model_topn")

grammar_topn Counter({'total': 123534, 'valid': 120070, 'other': 56968, 'music': 20491, 'video': 14052, 'global_command': 9740, 'alarm': 7516, 'general_command_3': 6918, 'fm': 5954, 'awakens': 3418, 'weather_1': 2044, 'en_word_3': 1988, 'album_preview_3': 1684, 'general_command_2': 1529, 'album_2': 1319, 'camera_2': 1291, 'smart_home_1': 1042, 'smart_home_3': 783, 'weather': 712, 'navigation_1': 657, 'general': 583, 'camera_1': 520, 'album_1': 448, 'videocall_1': 419, 'smart_home_2': 316, 'camera_take_3': 312, 'alarm_1': 268, 'alarm_3': 247, 'videocall_2': 123, 'camera_view_3': 113, 'navigation_2': 92, 'find_3': 80, 'applet_1': 78, 'album_list_3': 66, 'go_back_3': 56, 'children_mode_2': 50, 'empty': 46, 'en_word_2': 33, 'launcher_3': 32, 'chat_1': 32, 'applet_3': 19, 'weather_test': 16, 'testyzh': 15, 'location_3': 14, 'child_mode_3': 11, 'test1': 11, 'aiqytest_2': 10, 'cricetinae': 6, 'movie': 1, 'alarm_wy': 1, 'wjj_1585906237': 1})
model_topn Counter({'total': 123534, 'valid': 120070

In [36]:
def save_jsonl(data, filename, descs):
    """Write ``data`` as JSON Lines, preceded by a single ``#``-prefixed header line.

    Args:
        data: iterable of JSON-serializable records, one per output line.
        filename: output path; an existing file is overwritten.
        descs: JSON-serializable metadata written on the first line after "#".

    The file is written as UTF-8 explicitly. ``ensure_ascii=False`` emits
    non-ASCII text (e.g. Chinese queries) verbatim, and the platform default
    encoding may not be able to represent it.
    """
    with open(filename, "w", encoding="utf-8") as fout:
        fout.write("#" + json.dumps(descs, ensure_ascii=False) + "\n")
        for rec in data:
            fout.write(json.dumps(rec, ensure_ascii=False) + "\n")

In [91]:
# Persist the sample; the header line records bucket sizes and the domain
# distributions of both the full data set and the sample.
save_jsonl(
    samples,
    "2020-06-22.jsonl",
    {
        "query_counts": {name: len(indexes) for name, indexes in t2idx.items()},
        "original_grammar_topn_counts": raw_grammar_topn_counts,
        "original_model_topn_counts": raw_model_topn_counts,
        "sample_grammar_topn_counts": sample_grammar_topn_counts,
        "sample_model_topn_counts": sample_model_topn_counts,
    },
)