In [1]:
"""
    Parses the ICSI dialogue corpus into transcript files with DA annotations. The
    transcribed text is already preprocessed, hence this script just parses it,
    generating a more friendly tsv file.
    
    The original dataset has 2083 unique DAs, but they can be mapped into a smaller
    collection of labels by using the maps provided in the classmaps/ folder.
    
    In this notebook we map the 2083 classes into 5 classes, namely: D (disruption),
    B (backchannel), F (filler), S (statement), Q (question). The Z class
    (non-labelled) is removed from the clean dataset.
    
    [CONTENTS]
        - Parsing
        - Data peeking
        - Label occurrence statistics
        - Splitting data into train/dev/test
"""


import re
import csv
import glob
from os import path, getcwd, remove
from random import random
from collections import OrderedDict, defaultdict, Counter

import numpy as np

# folder holding the dataset transcriptions and the glob pattern matching them
DATASET_PATH = path.join(getcwd(), 'icsi/data/')
DATASET_ORIGINAL_EXT = '*.dadb'

# the class map used to collapse the original DA tags into the 5 classes listed
# in the notebook header (path is relative to the current working directory)
CLASS_MAP_PATH = 'icsi/classmaps/map_01b_expanded'

# destination files of the merged train/dev/test splits (with punctuation)
TRAIN_PUNCT_PATH  = path.join(getcwd(), 'clean/icsi_train.tsv')
DEV_PUNCT_PATH    = path.join(getcwd(), 'clean/icsi_dev.tsv')
TEST_PUNCT_PATH   = path.join(getcwd(), 'clean/icsi_test.tsv')

In [2]:
# Meeting ids assigned to each split; every id is the name (without extension)
# of one transcription file inside the dataset folder.
train_set_idx = [
    'Bdb001',
    'Bed002', 'Bed004', 'Bed005', 'Bed008', 'Bed009', 'Bed011', 'Bed013',
    'Bed014', 'Bed015', 'Bed017',
    'Bmr002', 'Bmr003', 'Bmr006', 'Bmr007', 'Bmr008', 'Bmr009', 'Bmr011',
    'Bmr012', 'Bmr015', 'Bmr016', 'Bmr020', 'Bmr021', 'Bmr023', 'Bmr025',
    'Bmr026', 'Bmr027', 'Bmr029', 'Bmr031',
    'Bns001', 'Bns002', 'Bns003',
    'Bro003', 'Bro005', 'Bro007', 'Bro010', 'Bro012', 'Bro013', 'Bro015',
    'Bro016', 'Bro017', 'Bro019', 'Bro022', 'Bro023', 'Bro025', 'Bro026',
    'Bro028',
    'Bsr001',
    'Btr001', 'Btr002',
    'Buw001',
]

dev_set_idx = [
    'Bed003', 'Bed010',
    'Bmr005', 'Bmr014', 'Bmr019', 'Bmr024', 'Bmr030',
    'Bro004', 'Bro011', 'Bro018', 'Bro024',
]

test_set_idx = [
    'Bed006', 'Bed012', 'Bed016',
    'Bmr001', 'Bmr010', 'Bmr022', 'Bmr028',
    'Bro008', 'Bro014', 'Bro021', 'Bro027',
]

In [3]:
class Parser:
    """
    Converts ICSI ``.dadb`` transcription files into tab-separated files whose
    rows carry the utterance id, its original and mapped labels, the error
    code, the speaker and the cleaned text.

    While parsing, the instance accumulates:
        - ``samples``: a list of ``(mapped_label, clean_text)`` tuples, one per
          written row, across every file preprocessed so far;
        - ``freq``: how many written rows carry each mapped label.
    """

    # these transcriptions are removed according to the dataset recommendations
    BLACKLIST = ['Bmr013.dadb', 'Bmr018.dadb']

    # column names written as the first row of every generated tsv file
    HEADER = ['id', 'original_label', 'label', 'errcode', 'speaker', 'clean']

    # utterances flagged with one of these error codes are dropped
    SKIPPED_ERROR_CODES = ('B', 'D', 'Z')

    # mapped labels identifying a non-labelled utterance, which is dropped
    SKIPPED_LABELS = (' ', '', 'Z')

    def __init__(self, amount_peeks=100):
        """
        Loads the class map and initialises the empty statistics.

        amount_peeks is only stored on the instance; no method of this class
        reads it.
        """
        self.classmap = self._load_classmap()
        self.samples = list()
        self.freq = defaultdict(int)
        self.amount_peeks = amount_peeks

    def clean(self, text):
        """
        Strips transcription markup from ``text`` and returns the result
        without leading/trailing whitespace.

        Note that the function is not idempotent: removing a symbol that sits
        between two spaces leaves a double space that only a second call
        collapses.
        """
        # empty braces carry no text
        text = re.sub(r'\{\}', '', text)
        # a hyphen right after a word character becomes a word separator
        text = re.sub(r'(\w)-', r'\g<1> ', text)
        text = re.sub(r'\s{2,}', ' ', text)

        # i_d -> id, x_m_l -> xml, i_seventeen -> iseventeen
        text = re.sub(r'[_\-<>{}\|\.]', '', text)
        return text.replace('@reject@', '').strip()

    def _join_fragmented_text(self, text):
        """
        Transforms the timed utterances from the .dadb file format into
        a regular string.

        The field is split on '|' and only what follows the last '+' of each
        piece is kept; the pieces are joined with spaces and cleaned.
        """

        nodes = text.split('|')
        nodes = [node[node.rfind('+')+1:] for node in nodes]
        text = ' '.join(nodes)
        return self.clean(text)

    def _load_classmap(self):
        """
        Reads the tab-separated class map found at CLASS_MAP_PATH (relative to
        the current working directory) and returns it as an OrderedDict that
        maps each original tag to its reduced tag.

        According to the original notes of this notebook the reduced tags are
        D (disruption), B (backchannel), F (filler), S (statement), Q
        (question) and Z (non-labelled), plus an X tag that is mapped to
        itself and is usually associated with utterances removed from the
        dataset because of a transcription error code (ie: bleeped line, no
        words).

        Raises ValueError when a non-empty row does not have exactly two
        columns.
        """

        map_path = path.join(getcwd(), CLASS_MAP_PATH)

        classmap = OrderedDict()
        with open(map_path, 'r') as file:
            reader = csv.reader(file, delimiter='\t')
            for line in reader:
                # csv yields an empty list for blank lines; nothing to map
                if not line:
                    continue

                label, new_label = line
                classmap[label] = new_label

        return classmap

    def read_trans_file(self, path):
        """
        Reads a comma-separated transcription file and returns a dict mapping
        the first column of each row (the sentence id) to its last column (the
        punctuated sentence).
        """
        table = dict()

        with open(path) as file:
            reader = csv.reader(file)
            for line in reader:
                # csv yields an empty list for blank lines; there is no id
                if not line:
                    continue

                sentence_id = line[0]
                punct_sentence = line[-1]

                table[sentence_id] = punct_sentence

        return table

    def preprocess(self, origin_path, destination_path, punctuation_file=None):
        """
        Reads a .dadb utterance file, extract the necessary information and
        generates a new tsv file that can be used for DA classification. The
        generated tsv file contains the fields id, original label, mapped label,
        error_code, speaker_id and the text.

        The transcribed text is mostly preprocessed, so we just format it into
        a regular string.

        origin_path: the .dadb file to read. Files whose name is listed in
            BLACKLIST are ignored and no output file is created for them.
        destination_path: the tsv file to (over)write.
        punctuation_file: optional transcription file read with
            read_trans_file. When given, the text of each utterance comes from
            it; an utterance that is missing from it is reported on stdout and
            falls back to the text stored in the .dadb file instead of
            aborting the whole parse.

        Every written row also updates ``self.freq`` and ``self.samples``.
        """

        if path.basename(origin_path) in self.BLACKLIST:
            return

        punct_table = None
        if punctuation_file:
            punct_table = self.read_trans_file(punctuation_file)

        # both files are closed even when a malformed row raises; newline=''
        # lets the csv module control the line endings of the output
        with open(origin_path, 'r') as origin, \
                open(destination_path, 'w', newline='') as dest:
            writer = csv.writer(dest, delimiter='\t')
            writer.writerow(self.HEADER)

            reader = csv.reader(origin, delimiter=',')

            # See doc/database-format.txt for details about the indices
            for line in reader:
                # csv yields an empty list for blank lines
                if not line:
                    continue

                utterance_id = line[2]
                error_code = line[3]
                label = line[5]
                speaker = line[7]
                original_label = line[8]

                text = None
                if punct_table is not None:
                    text = punct_table.get(utterance_id)
                    if not text:
                        print('----> COULD NOT FIND SENTENCE <-----')

                # no punctuation file, or the utterance is absent from it
                if text is None:
                    fragmented_text = line[4]
                    text = self._join_fragmented_text(fragmented_text)

                text = self.clean(text)

                # see line 1771 of the used classmap
                if label == 's^e.%-:s.%--':
                    label = 's^e.%-:s.%-'

                # labels absent from the class map are kept as they are
                label_ = self.classmap.get(label, label)

                # skipping error and not-labelled (always related to error-marked utterances)
                if error_code in self.SKIPPED_ERROR_CODES:
                    continue

                # Z = not-labelled, so it is skipped
                if label_ in self.SKIPPED_LABELS:
                    continue

                writer.writerow([utterance_id, original_label, label_, error_code, speaker, text])

                self.freq[label_] += 1
                self.samples.append((label_, text))

In [4]:

def merge_transcriptions(destination_writer, origin_file_path):
    """
    Appends the data rows of a parsed (tab-separated) transcription file to an
    already opened csv writer. The first row of the origin file is its header
    and is not copied.
    """

    with open(origin_file_path, 'r') as source:
        rows = csv.reader(source, delimiter='\t')

        # consume the header; the default keeps an empty file from raising
        next(rows, None)

        for row in rows:
            destination_writer.writerow(row)

            
def merge_files(destination_path, origin_files_base_path, filenames):
    """
    Puts all samples from a split into a single file.

    destination_path: the tsv file to (over)write; it starts with a header row.
    origin_files_base_path: folder holding the parsed transcription files.
    filenames: names (without the '.tsv' extension) of the parsed files whose
        rows are appended, in the given order.
    """

    # newline='' lets the csv module control the line endings of the output
    with open(destination_path, 'w', newline='') as dest_file:
        writer = csv.writer(dest_file, delimiter='\t')
        writer.writerow(['id', 'original_label', 'label', 'errcode', 'speaker', 'clean'])

        for filename in filenames:
            file_path = path.join(origin_files_base_path, filename + '.tsv')
            merge_transcriptions(writer, file_path)


def generate_dataset_splits(train_set, dev_set, test_set, dest_train,
                            dest_dev, dest_test, use_punctuation=False):
    """
    Merges multiple transcript files into a single file representing the dataset.

    Every transcription matching DATASET_ORIGINAL_EXT inside DATASET_PATH is
    first parsed into a '.tsv' file written next to it; the parsed files named
    by each split are then merged into that split's destination file.

    train_set, dev_set, test_set: names (without extension) of the
        transcriptions that belong to each split.
    dest_train, dest_dev, dest_test: destination file of each split.
    use_punctuation: when true, the text of each utterance is taken from the
        '.trans' file that sits next to the transcription.

    Returns the Parser used, which holds the samples and label frequencies of
    every parsed transcription (not only of those assigned to a split).
    """

    parser = Parser()

    # glob yields files in arbitrary order; sorting keeps parser.samples (and
    # anything printed from it) reproducible between runs
    transcriptions = sorted(glob.glob(path.join(DATASET_PATH, DATASET_ORIGINAL_EXT)))

    # for each transcription in the dataset folder extract its transcriptions
    for transcription in transcriptions:
        filename = path.basename(transcription).split('.')[0]
        destin_file = path.join(getcwd(), DATASET_PATH, filename + '.tsv')

        if use_punctuation:
            # built from the folder and the file name: splitting the whole
            # path on '.' would cut it at a dot found in any directory name
            punctuation_file = path.join(path.dirname(transcription), filename + '.trans')
        else:
            punctuation_file = None

        # reads the dadb (optionally the trans) file and generates a formatted
        # transcription file
        parser.preprocess(transcription, destin_file, punctuation_file)

    # merging and persisting transcription files into split files
    merge_files(dest_train, DATASET_PATH, train_set)
    merge_files(dest_dev, DATASET_PATH, dev_set)
    merge_files(dest_test, DATASET_PATH, test_set)

    return parser

In [5]:
# parse the whole corpus and write the punctuated train/dev/test split files
parser = generate_dataset_splits(
    train_set_idx,
    dev_set_idx,
    test_set_idx,
    TRAIN_PUNCT_PATH,
    DEV_PUNCT_PATH,
    TEST_PUNCT_PATH,
    use_punctuation=True,
)

In [6]:
# data peeking: mapped label and (at most 90 chars of) the text of the
# first 30 parsed samples
print('{:7.7}\t{:90.90}'.format('-label-', '-original text (90 chars)-'))
for label, text in parser.samples[:30]:
    print('{:7.7}\t{:90.90}\t'.format(label, text))

-label-	-original text (90 chars)-                                                                
S      	ok  let's be done with this                                                               	
S      	ok                                                                                        	
S      	ok                                                                                        	
D      	this is ami who ==                                                                        	
S      	and this is tilman and ralf                                                               	
S      	hi                                                                                        	
S      	uh huh  nice to meet you                                                                  	
S      	hi                                                                                        	
S      	hi                                                                                        	
S

In [7]:
# Label frequencies, most frequent label first, followed by the grand total

freqs = sorted(parser.freq.items(), key=lambda pair: pair[1], reverse=True)
for label, freq in freqs:
    print('{:12.12} {}'.format(label, freq))

total = sum(count for _, count in freqs)
print('{:12.12} {}'.format('Total', total))

S            62589
D            14814
B            14164
F            7683
Q            6797
Total        106047
