In [1]:
import glob, os
TEST_DATA_FOLDERS = [
    "/n/sd7/trung/csp/data/erica/annotation/test/interview",
    "/n/sd7/trung/csp/data/erica/annotation/test/dating",
    "/n/sd7/trung/csp/data/erica/annotation/test/attentive"
]
TRAIN_DATA_FOLDERS = [
    "/n/sd7/trung/csp/data/erica/annotation/kyoto17/spring/interview",
    "/n/sd7/trung/csp/data/erica/annotation/kyoto17/spring/dating",
    "/n/sd7/trung/csp/data/erica/annotation/kyoto17/spring/attentive",
    "/n/sd7/trung/csp/data/erica/annotation/kyoto16/interview",
    "/n/sd7/trung/csp/data/erica/annotation/kyoto16/dating",
    "/n/sd7/trung/csp/data/erica/annotation/kyoto16/attentive",
    "/n/sd7/trung/csp/data/erica/annotation/kyoto16/labintro",
]

mode = "train"

from subprocess import call
from struct import unpack, pack
import numpy as np

import IPython
import wave
from pydub import AudioSegment
import re
import random

import MeCab
mt = MeCab.Tagger("-Owakati")
mt.parse('')

OUTPUT_FOLDER = "/n/sd7/trung/csp/data/erica"

BATCH_SIZE = 32

In [5]:
words = [s.strip().split(' ', 1) for s in open('/n/rd32/mimura/e2e/data/script/aps_sps/word.id', encoding='eucjp')]
decoder_map_word = {word[0].split('+')[0]: int(word[1]) for word in words}

# words = [s.strip().split(' ', 1) for s in open('/n/sd7/trung/csp/data/erica/word_ids.txt') if s != ""]
# decoder_map_word = {word[0]: int(word[1]) for word in words}
# decoder_map_word = {'<unk>': 0, '<sos>': 1, '<eos>': 2, '<sp>': 3}
print("Vocab Size:", len(decoder_map_word))

TAG_COUNT = 17
def get_tag_id(tag):
    for i, t in enumerate(['pQ', 'cQ', 'sQ', 'checkQ', 'inf', 'off', 'pro', 'sug', 'req', 'ans', 'arg', 'disarg', 'cor', 'acc', 'dec', 'bc', 'oth']):
        if tag.startswith(t): return i
    return None

Vocab Size: 30638


In [6]:
vocab = {}
def get_word_id(word, oov, line, return_word=False):
    ret = []
    start = 0
    end = len(word)
    while start < len(word):
        if start >= end: 
            if word in oov: oov[word] += 1
            else: 
                oov[word] = 1
                # print(line, word)
            # return ret + [-1]
            return ['<UNK>'] if return_word else [decoder_map_word['<UNK>']]
        if word[start:end] in decoder_map_word:
            if not return_word: ret.append(decoder_map_word[word[start:end]])
            else: ret.append(word[start:end])
            start = end
            end = len(word)
        else:
            end -= 1
    return ret

def split_words(s, oov, return_word=False):
    s = " ".join(s)
    s.replace("　", " ")
    #s = s.replace(' <', '<')
    #s = s.replace(' ', ' <sp> ')
    #s = s.replace('<', ' <')
    s = s.split(' ')
    ret = []
    for ss in s:
        ss = ss.replace(' ', '')
        if ss and ss[0] =='<':
            ret.append(ss if return_word else decoder_map_word['<sp>'])
            continue
        if ss == '': continue
        tokens = mt.parse(ss).strip()
        tokens = tokens.split(' ')
        for t in tokens:
            ret += get_word_id(t, oov, s, return_word)
    return ret
    
def preproc(line):
    for c in ['I', 'L', 'F', 'D', 'N']:
        line = re.sub(r'\(' + c + '([^\)]*)\)', r'\1', line)
    line = re.sub(r'\(\?([^\)]*)\)', r'\1', line)
    for c in list('`"「」') + ['(L', 'L)', '(F', ' L']:
        line = line.replace(c, '')
    line = re.sub(r'%.*', r'', line)
        
    line = line.replace('{LAUGH}', '')
    line = line.replace('{LAuGH}', '')
    line = line.replace('{COUGH}', '')
    line = re.sub(r'\(P [0-9]*\)', r'', line)
    
    k = line.find('/')
    if k != -1:
        tags = [line[k + 1:].split(';')[0]]
        line = line[:k]
    else: tags = []
        
    env = None
    for i in range(len(tags)):
        if tags[i].find('/') >= 0:
            env, tags[i] = tags[i].split('/', maxsplit=1)
    # tags = re.findall(r'\/(\w*)', line)
    line = re.sub(r'\/\w*', '', line)
    # if len(tags) > 0: line += "<" + tags[0] + ">"
    # if len(tags) > 1: print(tags)
    # print(line)
    return line.strip(), tags, env

In [7]:
filepaths = []
targets = []

targets = []
oov = {}
dialogs = {}
for DATA_FOLDER in TEST_DATA_FOLDERS if mode == "test" else TRAIN_DATA_FOLDERS:
    for file in glob.glob(os.path.join(DATA_FOLDER, "*.txt")):
        # print(file)
        is_start = True
        date, id, obj = os.path.basename(file)[:-4].split('_')[:3]
        longid = date + "_" + id
        year = date[:4]
        
        dialogs[longid] = []
        # if longid != "20171128_03": continue
        
        encoding = 'shift-jis'
        if longid in ['20161205_03', '20161209_05', '20161215_04', '20161208_03', '20161205_07', '20161212_01', '20161209_04', 
                      '20161209_03', '20161208_01', '20161212_03', '20161208_02', '20161207_05', '20161215_03', '20161207_04',
                     '20161205_04']: encoding = 'utf-8'
        
        lines = open(file, encoding=encoding).read().split('\n')
        # print(lines)

        wavpath = '/n/sd7/trung/csp/data/erica/dialogue/%s/%s/%s/%s_audio_mix.wav' % (year, date, longid, longid)
        if not os.path.exists(os.path.join(OUTPUT_FOLDER, "wav", longid)):
            os.mkdir(os.path.join(OUTPUT_FOLDER, "wav", longid))
        if not os.path.exists(os.path.join(OUTPUT_FOLDER, "htk", longid)):
            os.mkdir(os.path.join(OUTPUT_FOLDER, "htk", longid))
        if not os.path.exists(os.path.join(OUTPUT_FOLDER, "npy", longid)):
            os.mkdir(os.path.join(OUTPUT_FOLDER, "npy", longid))

        k = 0
        while k < len(lines):
            line = lines[k].strip()
            if longid == '20161209_05':
                start, end = line.split(' ')[1:3]
            else:
                start, end = line.split(' ')[1].split('-')
            start, end = float(start) * 1000, float(end) * 1000
            utt_id = line.split(' ')[0]

            output_wav = os.path.join(OUTPUT_FOLDER, "wav", longid, "%s_%s.wav" % (utt_id, obj))
            output_htk = os.path.join(OUTPUT_FOLDER, "htk", longid, "%s_%s.htk" % (utt_id, obj))
            output_npy = os.path.join(OUTPUT_FOLDER, "npy", longid, "%s_%s.npy" % (utt_id, obj))
            
            if line[-2:] == obj + ':':
                k += 1
                s = []
                tags = []
                envs = set()
                while k < len(lines):
                    line = lines[k].strip()
                    if line[-2:] == obj + ':': break

                    line, tag, env = preproc(line)
                    if env: envs.add(env)
                    if tag: tags += tag
                    if line != '': s.append(line)
                    # if len(tag) > 0: s.append("<" + tag[0] + ">")
                    k += 1
                
                s = split_words(s, oov, False)
                
                tags = '_'.join([';'.join(tag) for tag in tags])
                
                if len(s) >= 2 and float(end) - float(start) < 10000:
                    #targets.append(' '.join([str(k) for k in s]))
                    dialogs[longid].append((start, end, obj, output_npy, ' '.join([str(k) for k in s]), is_start))
                    is_start = False
                    #targets.append("%s %07d %07d %s %s %s %s" % (longid, start, end, utt_id, obj, ''.join(list(envs)) or '_', s))
            else: k += 1
        dialogs[longid].sort()

In [8]:
print("Dialogs:", len(dialogs))
keys = list(dialogs.keys()) * 20
random.shuffle(keys)

dlgs = [i for i in range(BATCH_SIZE)]
dlg_pos = [0 for i in range(len(dlgs))]
next_dlg = BATCH_SIZE

fo = open(os.path.join(OUTPUT_FOLDER, "train_group_by_dlg_16.txt"), 'w')
count = 0
while next_dlg <= len(keys):
    for i in range(BATCH_SIZE):
        dlg = dialogs[keys[dlgs[i]]]
        if dlg_pos[i] >= len(dlg):
            dlgs[i] = next_dlg
            next_dlg += 1
            dlg_pos[i] = 0

        utt = dlg[dlg_pos[i]]
        fo.write("%s\t%d\t%d\t2 %s 1\n" % (utt[3], 1 if dlg_pos[i] == 0 else 0, 1 if utt[2] == 'U' else 0, utt[4]))
        dlg_pos[i] += 1
        count += 1
        #print(count, "%s\t%d\t%d\t2 %s 1\n" % (utt[3], 1 if dlg_pos[i] == 0 else 0, 1 if utt[2] == 'U' else 0, utt[4]))
            
print(count)
fo.close()

Dialogs: 83
243680


In [9]:
fo = open(os.path.join(OUTPUT_FOLDER, "test_group_by_dlg_16.txt"), 'w')
count = 0
for dlg in dialogs:
    for utt in dlg:
        fo.write("%s\t%d\t%d\t2 %s 1\n" % (dlg[3], 1 if dlg_pos[i] == 0 else 0, 1 if dlg[2] == 'U' else 0, dlg[4]))
        count += 1
fo.close()
print(count)

913
