In [54]:
import json

import pandas as pd
from pathlib2 import Path
from collections import defaultdict

from vocab import Language
from data import parse_labels

In [11]:
def read_raw_data(data_dir, lang):
    if lang == Language.chinese:
        with open(data_dir / "train_data_cn.json", 'r', encoding="utf-8") as train_fp:
            train_json = json.load(train_fp)
        with open(data_dir / "test_data_cn.json", 'r', encoding="utf-8") as test_fp:
            test_json = json.load(test_fp)
        return train_json, test_json

In [38]:
def process_raw_data(raw_data, is_train):    
    def data_gen():
        for dialogue in raw_data:
            senders = []
            texts = []
            for turn in dialogue["turns"]:
                sender = 1 if turn["sender"].startswith("c") else 0
                senders.append(sender)
                text = " ".join(turn["utterances"])
                texts.append(text)
                
            if is_train:
                customer_nugget_label, helpdesk_nugget_label, quality_label = \
                    parse_labels(dialogue["annotations"], senders)
                yield (dialogue["id"],
                       senders,
                       texts,
                       customer_nugget_label,
                       helpdesk_nugget_label,
                       quality_label)

            else:
                yield (dialogue["id"],
                       senders,
                       texts,
                       dialogue_length)
    data = [x for x in data_gen()]
    return data

In [52]:
data_dir = Path("stc3dataset/data")
lang = Language.chinese
raw_train, raw_test = read_raw_data(data_dir, lang)

In [39]:
data_train = process_raw_data(raw_train, True)

In [51]:
# No "PAD" nuggets?
for dialogue in data_train:
    for c_label in dialogue[3]:
        if c_label[0] != 0:
            print("Found")
    for h_label in dialogue[4]:
        if h_label[0] != 0:
            print("Found")

In [53]:
data_train[0]

('3830401296796826',
 [1, 0],
 ['中国电信的控制箱就这样吗？也没有人维护,信息安全和人身安全怎么保障？还是好好修修吧？这个应该不差钱吧？位于济南市新泺大街雅居园小区门对面.@中国电信 @中国电信客服 @中国电信济南客服',
  '您好,您反映的情况我们已认真记录,会及时向相关部门反馈,敬请等待[呵呵]；'],
 [array([0.        , 0.7368421 , 0.05263158, 0.        , 0.21052632],
        dtype=float32)],
 [array([0.        , 0.47368422, 0.        , 0.5263158 ], dtype=float32)],
 array([[0.        , 0.05263158, 0.6315789 , 0.21052632, 0.10526316],
        [0.        , 0.15789473, 0.57894737, 0.05263158, 0.21052632],
        [0.        , 0.        , 0.7368421 , 0.        , 0.2631579 ]],
       dtype=float32))

In [58]:
at_tags_dict = defaultdict(list)
for dialogue in data_train:
    for i, sender in enumerate(dialogue[1]):
        if sender == 1:
            tokens = dialogue[2][i].split()
            for token in tokens:
                if token.startswith('@'):
                    at_tags_dict[token].append(dialogue[0])

In [59]:
len(at_tags_dict)

1121

In [None]:
def data_to_df(data, is_train):
    df = pd.DataFrame()
    for dialogue in data:
        for i, sender in enumerate(dialogue[1]):
            if sender == 1:
                

In [None]:
df_train