In [None]:
from swda import Transcript
import glob, os, re, random, pathlib
import numpy as np
from tqdm import tqdm as tqdm
from collections import namedtuple
import utils
from IPython.display import HTML, display
import pandas as pd
from pydub import AudioSegment

# Root directory of the Switchboard corpus used throughout this notebook.
DATA_PATH = '/n/sd7/trung/csp/data/swbd'
# Several cells refer to the data root as DATA_FOLDER, which was never defined
# (NameError on a fresh kernel). FEATURE_PATH below is DATA_PATH/feature/numpy,
# the same location other cells build from DATA_FOLDER, so it is aliased here.
# NOTE(review): confirm DATA_FOLDER was always meant to be the same root.
DATA_FOLDER = DATA_PATH
# Directory under which the per-utterance feature arrays are saved.
FEATURE_PATH = '/n/sd7/trung/csp/data/swbd/feature/numpy'


def htk_path(dlgid, caller):
    """Return the path of the HTK feature file for one side of a conversation.

    Args:
        dlgid: Conversation id, inserted after the ``sw0`` file-name prefix.
        caller: Caller label appended after the dash (the notebook uses 'A'/'B').

    Returns:
        ``<DATA_FOLDER>/htk/swbd/sw0<dlgid>-<caller>.htk`` as a string.
    """
    return os.path.join(DATA_FOLDER, "htk", "swbd", "sw0%s-%s.htk" % (dlgid, caller))

In [None]:
# Load the word-level transcripts, grouped by conversation id.
trans_utts = {}
dlg_utts = {}

# glob returns files in arbitrary order; sorting keeps the load order (and
# therefore the tie-breaking of the sort below and the utterance ids derived
# from it) reproducible between runs.
transcript_files = sorted(
    glob.glob(os.path.join(DATA_PATH, "swb_ms98_transcriptions", "*", "*", "*word.text"))
)

for transfile in tqdm(transcript_files, desc="Load transcript"):
    # The conversation id is taken from characters 2-5 of the file name.
    dlgid = os.path.basename(transfile)[2:6]

    # DATA_FOLDER was never defined in this notebook; the data root is DATA_PATH.
    # makedirs also creates a missing "wav" parent and tolerates re-runs.
    os.makedirs(os.path.join(DATA_PATH, "wav", dlgid), exist_ok=True)

    dlg_utts.setdefault(dlgid, [])
    trans_utts.setdefault(dlgid, []).extend(utils.read_word_transcript_file(transfile))

# Order each conversation by the start of its first word. Sorting once per
# conversation gives the same result as re-sorting after every file because
# list.sort is stable.
for dialog_utterances in trans_utts.values():
    dialog_utterances.sort(key=lambda utt: utt['words'][0]['start'])

In [None]:
def map_word(word):
    """Normalise a transcript token.

    The token is lower-cased and every occurrence of the substring ``_1`` is
    removed.

    Args:
        word: Raw token string.

    Returns:
        The normalised token string.
    """
    return word.lower().replace('_1', '')
    
# Reduce every transcript utterance to the fields used downstream. An
# utterance spans from the start of its first word to the end of its last
# word; `id` is its position within the (start-sorted) conversation.
dlg_utts = {
    dlgid: [
        dict(
            id=i,
            start=utt['words'][0]['start'],
            end=utt['words'][-1]['end'],
            caller=utt['caller'],
            trans_words=[map_word(word['word']) for word in utt['words']],
        )
        for i, utt in enumerate(trans_utts[dlgid])
    ]
    for dlgid in trans_utts
}

# Summed duration of all utterances, in the transcript's time unit.
utterance_frames = sum(
    utt['end'] - utt['start']
    for dialog in dlg_utts.values()
    for utt in dialog
)
# A conversation lasts until its latest-ending utterance. Utterances are
# ordered by start time, so the last one in the list is not necessarily the
# one that ends last; default=0 keeps an empty conversation from raising.
conversation_frames = sum(
    max((utt['end'] for utt in dialog), default=0)
    for dialog in dlg_utts.values()
)

# The / 3600 / 100 conversion assumes times are counted at 100 units per
# second -- NOTE(review): confirm against utils.read_word_transcript_file.
print("Dataset Overview")
print("Conversations:", len(dlg_utts))
print("Utterances:", sum(len(dialog) for dialog in dlg_utts.values()))
print("Utterances' length (sum): %.2f hours" % (utterance_frames / 3600 / 100))
print("Total length: %.2f hours" % (conversation_frames / 3600 / 100))

In [None]:
# Preview one randomly chosen conversation as a table (one row per utterance).
sample_dlgid = random.choice(list(dlg_utts.keys()))
utts = [
    [row['caller'], row['start'], row['end'], ' '.join(row['trans_words'])]
    for row in dlg_utts[sample_dlgid]
]
# Limit how many rows pandas renders; this option stays set for later cells.
pd.set_option('display.max_rows', 10)
pd.DataFrame(utts, columns=["caller", "start", "end", "transcript"])

In [None]:
# Extract speech: cut every utterance (plus some surrounding padding) out of
# the per-caller HTK feature files, normalise it with corpus-wide statistics
# and save it as one .npy file per utterance.
from htk import read as read_htk

sil_duration = 25  # padding, in frames, added on each side of an utterance
PREFIX = "swb_padding25"
feature_dim = 120


def _padded_bounds(utterances, num_frames, padding):
    """Compute the padded frame range of each utterance of one caller.

    Each utterance is extended by up to `padding` frames on both sides, but
    never past the midpoint of the gap to the neighbouring utterance, and the
    range is never allowed to cut into the utterance itself.

    Args:
        utterances: Utterance dicts of a single caller, in order, each with
            integer 'start' and 'end' frame indices.
        num_frames: Number of frames in the caller's feature file.
        padding: Maximum number of frames to add on each side.

    Returns:
        A list of (start, end) frame index pairs, one per utterance.
    """
    bounds = []
    last = len(utterances) - 1
    for i, utt in enumerate(utterances):
        start, end = utt['start'], utt['end']

        if i == 0:
            left = max(start - padding, 0)
        else:
            prev_end = utterances[i - 1]['end']
            # min(start, ...) mirrors the max(end, ...) guard on the right
            # side: when this utterance overlaps the previous one, the gap
            # midpoint lies after `start` and would otherwise cut into it.
            left = min(start, max(start - padding, (start + prev_end) // 2))

        if i == last and i > 0:
            right_limit = num_frames
        else:
            # A lone utterance uses the end of the file as its "next start".
            next_start = utterances[i + 1]['start'] if i < last else num_frames
            if i < last and end > next_start:
                print("Warning: utterances are overlapping.")
            right_limit = (next_start + end) // 2
        right = max(end, min(end + padding, right_limit))

        bounds.append((left, right))
    return bounds


def _iter_padded_utterances(desc):
    """Yield (dlgid, utt, frames) for every utterance of every conversation.

    `frames` is the slice of the caller's HTK feature array covering the
    padded utterance. `desc` is the progress-bar label.
    """
    for dlgid in tqdm(dlg_utts, desc=desc):
        for caller in ['A', 'B']:
            caller_utts = [utt for utt in dlg_utts[dlgid] if utt['caller'] == caller]
            if not caller_utts:
                continue  # nothing to extract, so do not read the file
            input_data, _, _ = read_htk(htk_path(dlgid, caller))
            segment_bounds = _padded_bounds(caller_utts, input_data.shape[0], sil_duration)
            for utt, (left, right) in zip(caller_utts, segment_bounds):
                yield dlgid, utt, input_data[left:right]


for dlgid in dlg_utts:
    pathlib.Path(os.path.join(FEATURE_PATH, PREFIX, dlgid)).mkdir(parents=True, exist_ok=True)

# Pass 1: mean over ALL extracted frames. Previously the statistics were
# finalised inside the per-conversation loop, so the mean came from the first
# conversation only and the frame counter was reset for every caller.
# Sums are accumulated in float64 to limit rounding error.
input_data_sum = np.zeros((feature_dim,), dtype=np.float64)
total_frame_num = 0
for _dlgid, _utt, frames in _iter_padded_utterances("calculate mean"):
    input_data_sum += np.sum(frames, axis=0, dtype=np.float64)
    # Count the frames actually sliced; a range reaching past the end of the
    # file is clipped by the slice.
    total_frame_num += frames.shape[0]
if total_frame_num < 2:
    raise ValueError("Need at least two feature frames to compute normalisation statistics.")
# Cast back to float32 so float32 features keep their dtype when normalised.
global_mean = (input_data_sum / total_frame_num).astype(np.float32)

# Pass 2: sample standard deviation around the global mean.
input_data_std = np.zeros((feature_dim,), dtype=np.float64)
for _dlgid, _utt, frames in _iter_padded_utterances("calculate std"):
    input_data_std += np.sum(np.square(frames - global_mean), axis=0, dtype=np.float64)
global_std = np.sqrt(input_data_std / (total_frame_num - 1)).astype(np.float32)

# Pass 3: normalise every utterance and save it as <id><caller>.npy.
for dlgid, utt, frames in _iter_padded_utterances("extract speech"):
    input_utt = (frames - global_mean) / global_std
    np.save(os.path.join(FEATURE_PATH, PREFIX, dlgid, "%s%s.npy" % (utt['id'], utt['caller'])), input_utt)

In [None]:
# Conversation ids per data split. 'test' and 'dev' are fixed lists;
# 'dev_train' is everything outside 'test', and 'train' is everything outside
# both 'test' and 'dev'. Derived splits keep the order of dlg_utts.
_keys = list(dlg_utts.keys())
dlg_ids = {
    'test': ['2689', '2389', '3016', '2898', '2679', '3567', '2373', '3250', '2478', '2929', '4078', '2821', '4770', '3170', '3198', '2064', '2562', '3271', '3061', '2982'],
    'dev': ['3208', '2935', '3162', '2854', '2692', '2437', '3711', '2511', '3203', '3257', '2386', '3290', '3184', '2495', '2959', '3231', '2723', '3280', '3686', '3102', '3346', '2292', '3057', '3214', '2524', '2884', '2693', '2471', '2675', '2834', '2457', '3276', '3013', '2432', '2991', '3270', '2205', '2967', '2623', '3352'],
}
_test_ids = set(dlg_ids['test'])
_dev_ids = set(dlg_ids['dev'])
dlg_ids['dev_train'] = [key for key in _keys if key not in _test_ids]
dlg_ids['train'] = [key for key in _keys if not (key in _test_ids or key in _dev_ids)]

In [None]:
def build_vocab(dlgs, output=None): # build words
    """Build a word -> index vocabulary from the given conversations.

    Words are read from the module-level `dlg_utts`, lower-cased and counted;
    empty tokens and tokens whose first character is outside 'a'..'z' are
    ignored. Words are ranked by descending frequency (ties keep first-seen
    order), the last ranked word is dropped, and "<oov>" is placed at index 0.

    Args:
        dlgs: Iterable of conversation ids; ids missing from `dlg_utts` are
            skipped.
        output: Optional file path; when given, the vocabulary is written to
            it with one word per line, in index order.

    Returns:
        Dict mapping each word to its index.
    """
    counts = {}
    for dlgid in dlgs:
        if dlgid not in dlg_utts:
            continue
        for utt in dlg_utts[dlgid]:
            for token in utt['trans_words']:
                token = token.lower()
                # Skip empty tokens and anything not starting with a-z.
                if token == '' or token[0] < 'a' or token[0] > 'z':
                    continue
                counts[token] = counts.get(token, 0) + 1

    ranked = sorted(counts, key=lambda token: counts[token], reverse=True)
    # NOTE(review): the least frequent (last ranked) word is dropped here;
    # the reason is not evident from this notebook -- confirm it is intended.
    words = ["<oov>"] + ranked[:-1]

    if output is not None:
        with open(output, 'w') as f:
            f.write('\n'.join(words))

    return dict(zip(words, range(len(words))))

# Build the vocabulary from every non-test conversation and store it on disk.
# DATA_FOLDER was never defined in this notebook; the data root is DATA_PATH.
vocab_path = os.path.join(DATA_PATH, "vocab", "words_swb.txt")
# build_vocab opens the file for writing, which fails if the directory is missing.
os.makedirs(os.path.dirname(vocab_path), exist_ok=True)
vocab = build_vocab(dlg_ids['dev_train'], output=vocab_path)
print("Vocab Size:", len(vocab))

In [None]:
# Write one tab-separated file per split, <PREFIX>_<split>.csv, with one row
# per utterance.
PREFIX = "swb"
# Utterances spanning this many frames or fewer are left out.
MIN_UTT_FRAMES = 5
# NOTE(review): the extraction cell saves features under "swb_padding25",
# while the rows written here point at "swda_full" -- confirm which feature
# set these files are meant to reference.
SOUND_FEATURE_SET = "swda_full"

headers = ['dialog_id', 'sound', 'start', 'end', 'sound_len', 'caller', 'dialog_act', 'text', 'target', 'predicted_text']
for mode in dlg_ids.keys():
    # DATA_FOLDER was never defined in this notebook; the data root is DATA_PATH.
    with open(os.path.join(DATA_PATH, '%s_%s.csv' % (PREFIX, mode)), 'w') as fo:
        fo.write('\t'.join(headers) + '\n')
        for dlgid in tqdm(dlg_ids[mode], desc=mode):
            if dlgid not in dlg_utts:
                continue
            for utt in dlg_utts[dlgid]:
                if len(utt['trans_words']) == 0:
                    continue
                if utt['end'] - utt['start'] <= MIN_UTT_FRAMES:
                    continue

                words = [word.lower() for word in utt['trans_words']]
                # Out-of-vocabulary words map to index 0 ("<oov>").
                target = ' '.join(str(vocab.get(word, 0)) for word in words)
                sound = os.path.join(
                    DATA_PATH, "feature", "numpy", SOUND_FEATURE_SET, dlgid,
                    "%d%s.npy" % (utt['id'], utt['caller']),
                )
                fo.write('\t'.join([
                    dlgid,
                    sound,
                    str(utt['start']), str(utt['end']),
                    str(utt['end'] - utt['start']),
                    utt['caller'],
                    "0",  # dialog_act: constant placeholder
                    ' '.join(words),
                    target,
                    target,  # predicted_text is written with the same ids as target
                ]) + '\n')