In [1]:
import re
import os
import numpy as np
from proscript import *
import sentencepiece as spm
import matplotlib.mlab as mlab
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import ast
from collections import Counter
from dataclasses import dataclass
from random import shuffle
from shutil import copyfile
import math

In [2]:
#LOAD HEROES CORPUS PATH DATA
# Text file listing the segments to process. The parsing cell below unpacks every
# line into four paths: ES txt, ES csv, EN txt, EN csv.
# NOTE(review): the file name suggests this is the evaluation subset rather than
# the full corpus — confirm before re-running the dev/eval split on it.
proscript_data_file = "heroes_segment_paths_eval.txt"

In [3]:
#Auxiliary read-write functions

def puncEstimate(punc):
    """Reduce a punctuation string to one of '.', ',', '?' or ''.

    The marks are tested in a fixed priority order and the first one found in
    `punc` decides the result, so e.g. '?!' reduces to '?' and '.,' to '.'.
    A string containing none of the known marks reduces to ''.
    """
    # (mark to look for, reduced form) — the order is the priority order.
    reductions = (
        ('.', '.'),
        (',', ','),
        ('?', '?'),
        ('!', '.'),
        (':', '.'),
        (';', ','),
        ('-', '.'),
    )
    for mark, reduced in reductions:
        if mark in punc:
            return reduced
    return ''

def get_proscript_transcript(input_proscript, punctuation_as_token=False, reduce_punc=True):
    """Render the words of a proscript as a single space-separated string.

    Args:
        input_proscript: object with a `word_list`; each word provides `word`,
            `punctuation_before` and `punctuation_after` strings.
        punctuation_as_token: if False, punctuation is glued to its word
            (before + word + after). If True, punctuation is emitted as
            separate tokens around the word.
        reduce_punc: only used when `punctuation_as_token` is True; passes each
            punctuation string through `puncEstimate` first.

    Returns:
        The transcript string ('' for an empty word list).
    """
    def punctuation_token(punc):
        return puncEstimate(punc) if reduce_punc else punc

    input_tokens = []
    for w in input_proscript.word_list:
        if not punctuation_as_token:
            input_tokens.append(w.punctuation_before + w.word + w.punctuation_after)
            continue

        if w.punctuation_before:
            before = punctuation_token(w.punctuation_before)
            # A mark with no reduced form (e.g. a quote) becomes ''; emitting it
            # would put an empty token, i.e. a double space, in the transcript.
            if before:
                input_tokens.append(before)
        input_tokens.append(w.word)
        if w.punctuation_after:
            after = punctuation_token(w.punctuation_after)
            if after:
                input_tokens.append(after)
    return ' '.join(input_tokens)

#for reading old style (punctuation_before only)
def get_proscript_transcript_2(input_proscript):
    """Render an old-style proscript (punctuation stored in `punctuation_before`).

    Each word's `punctuation_before` is attached to the end of the previous
    word. The last entry of the word list is dropped from the output after its
    punctuation has been attached to the word before it.
    NOTE(review): presumably old-style files end with a placeholder entry that
    only carries the closing punctuation — confirm.

    Returns:
        The transcript string ('' for an empty or single-entry word list).
    """
    tokens = []
    for w in input_proscript.word_list:
        word = w.word
        if w.punctuation_before:
            if tokens:
                tokens[-1] += w.punctuation_before
            else:
                # The first word has no predecessor to attach to (this used to
                # raise IndexError); keep the mark in front of the word itself.
                word = w.punctuation_before + word
        tokens.append(word)
    # Slicing instead of `del tokens[-1]` so an empty word list yields ''.
    return ' '.join(tokens[:-1])
    
    

def read_text_file(file):
    """Return the whole content of the text file at path `file` as one string."""
    handle = open(file, 'r')
    try:
        contents = handle.read()
    finally:
        handle.close()
    return contents

In [4]:
#PARSE PROSCRIPT DATA FILE (File with paths to segment proscripts)
# Every non-blank line must hold four whitespace-separated paths:
# ES txt, ES csv, EN txt, EN csv.
heroes_data_all = []
all_available_proscript_path = []  #input (EN) proscript paths
with open(proscript_data_file, 'r') as f:
    for line_no, l in enumerate(f, start=1):
        fields = l.split()
        if not fields:
            # A blank line (e.g. at the end of the file) used to crash the unpacking.
            continue
        if len(fields) != 4:
            # Same exception type as the failed unpacking, but it says where.
            raise ValueError("%s, line %i: expected 4 paths, found %i"
                             % (proscript_data_file, line_no, len(fields)))
        es_txt_path, es_csv_path, en_txt_path, en_csv_path = fields
        all_available_proscript_path.append(en_csv_path)
        heroes_data_all.append((es_txt_path, es_csv_path, en_txt_path, en_csv_path))

In [None]:
#SPLIT HEROES DATASET INTO DEVELOPMENT AND EVALUATION SETS
SPLIT_SEED = 0          # fixed so that re-running the cell reproduces the same split
dev_ratio = 0.5         # share of all segments that goes to the development set
validation_size = 200   # number of development segments held out for validation

# Shuffle a copy so heroes_data_all keeps the file order. A dedicated, seeded
# generator is used instead of the global `random.shuffle`, whose result
# changes on every run.
heroes_data_shuffled = heroes_data_all.copy()
np.random.RandomState(SPLIT_SEED).shuffle(heroes_data_shuffled)

# ceil() reproduces the old `while index < len * dev_ratio` loop: for an odd
# number of segments the development set receives the extra one.
num_dev = math.ceil(len(heroes_data_shuffled) * dev_ratio)
heroes_data_dev = heroes_data_shuffled[:num_dev]
heroes_data_eval = heroes_data_shuffled[num_dev:]

#Split Dev set to training and validation
if validation_size > len(heroes_data_dev):
    # The index-based loop used to fail here with a bare IndexError.
    raise ValueError("validation_size (%i) exceeds the development set size (%i)"
                     % (validation_size, len(heroes_data_dev)))
heroes_data_dev_val = heroes_data_dev[:validation_size]
heroes_data_dev_train = heroes_data_dev[validation_size:]

print("Development: %i (train: %i, val: %i)"%(len(heroes_data_dev), len(heroes_data_dev_train), len(heroes_data_dev_val)))
print("Evaluation: %i "%(len(heroes_data_eval)))

In [None]:
#PREPARE TEXT DATASET FROM HEROES SEGMENTS, PREPARE PATH FILES FOR SETS
# Paths of the combined EN / ES text files.
# NOTE(review): neither constant is referenced in the visible part of this
# notebook; the calls further down spell their output paths out literally.
heroes_txt_en_path = 'onmt-heroes/heroes.en.txt'
heroes_txt_es_path = 'onmt-heroes/heroes.es.txt'

def data2textfiles(dataset, en_txt_output, es_txt_output, get_en_transcript_from='txt', get_es_transcript_from='txt'):
    """Write line-aligned EN and ES transcript files for a dataset.

    Args:
        dataset: iterable of (es_txt_path, es_csv_path, en_txt_path, en_csv_path).
        en_txt_output: path of the English output file (overwritten).
        es_txt_output: path of the Spanish output file (overwritten).
        get_en_transcript_from: 'txt' reads the segment's text file and
            lower-cases it; 'csv' loads the proscript csv and renders it with
            `get_proscript_transcript_2`.
        get_es_transcript_from: same choice for the Spanish side.

    Raises:
        ValueError: if either source selector is not 'txt' or 'csv'.
    """
    # Validate before any file is opened: an unknown selector used to leave the
    # transcript variable unbound, or silently reuse the previous segment's text.
    for source in (get_en_transcript_from, get_es_transcript_from):
        if source not in ('txt', 'csv'):
            raise ValueError("transcript source must be 'txt' or 'csv', got %r" % (source,))

    def load_transcript(txt_path, csv_path, source):
        if source == 'txt':
            # NOTE(review): the text is written as read; a trailing newline in the
            # source file would add a blank line to the output — confirm inputs have none.
            return read_text_file(txt_path).lower()
        proscript = Proscript()
        proscript.from_file(csv_path)
        return get_proscript_transcript_2(proscript)

    with open(en_txt_output, 'w') as f_en, open(es_txt_output, 'w') as f_es:
        for es_txt_path, es_csv_path, en_txt_path, en_csv_path in dataset:
            en_transcript = load_transcript(en_txt_path, en_csv_path, get_en_transcript_from)
            es_transcript = load_transcript(es_txt_path, es_csv_path, get_es_transcript_from)

            f_es.write(es_transcript + '\n')
            f_en.write(en_transcript + '\n')
            
def data2datafile(dataset, datafile_output):
    """Write one tab-separated line of four paths per dataset entry.

    `dataset` yields (es_txt_path, es_csv_path, en_txt_path, en_csv_path);
    `datafile_output` is overwritten.
    """
    with open(datafile_output, 'w') as f:
        f.writelines(
            '\t'.join(str(path) for path in (es_txt_path, es_csv_path, en_txt_path, en_csv_path)) + '\n'
            for es_txt_path, es_csv_path, en_txt_path, en_csv_path in dataset
        )
                
# data2textfiles(heroes_data_all, 'onmt-heroes/heroes_all.en.txt', 'onmt-heroes/heroes_all.es.txt')
# data2textfiles(heroes_data_eval, 'onmt-heroes/heroes_eval.en.txt', 'onmt-heroes/heroes_eval.es.txt')
# data2textfiles(heroes_data_dev, 'onmt-heroes/heroes_dev.en.txt', 'onmt-heroes/heroes_dev.es.txt')
# data2textfiles(heroes_data_dev_train, 'onmt-heroes/heroes_dev_train.en.txt', 'onmt-heroes/heroes_dev_train.es.txt')
# data2textfiles(heroes_data_dev_val, 'onmt-heroes/heroes_dev_val.en.txt', 'onmt-heroes/heroes_dev_val.es.txt')

# data2datafile(heroes_data_eval, 'heroes_segment_paths_eval.txt')
# data2datafile(heroes_data_dev_train, 'heroes_segment_paths_dev_train.txt')
# data2datafile(heroes_data_dev_val, 'heroes_segment_paths_dev_val.txt')


In [None]:
#PREPARE PunkProse DATASET FOR PROSODIC PUNCTUATION MODEL ADAPTATION FOR ENGLISH
# Single place to change when the corpus lives somewhere else.
PUNKPROSE_DATA_DIR = '/Users/alp/Documents/Corpora/punk_data/heroes_corpus_Vsugardub'

punkProse_train_samples_dir = os.path.join(PUNKPROSE_DATA_DIR, 'train_samples')
punkProse_dev_samples_dir = os.path.join(PUNKPROSE_DATA_DIR, 'dev_samples')
punkProse_test_samples_dir = os.path.join(PUNKPROSE_DATA_DIR, 'test_samples')
punkProse_groundtruth_samples_dir = os.path.join(PUNKPROSE_DATA_DIR, 'test_groundtruth')

# Copy the English proscript csv of every segment into the directory of its set:
# dev_train -> train_samples, dev_val -> dev_samples, eval -> test_samples.
for samples_dir, dataset in ((punkProse_train_samples_dir, heroes_data_dev_train),
                             (punkProse_dev_samples_dir, heroes_data_dev_val),
                             (punkProse_test_samples_dir, heroes_data_eval)):
    # copyfile() does not create missing directories.
    os.makedirs(samples_dir, exist_ok=True)
    for _, _, _, en_csv_path in dataset:
        dst_file = os.path.join(samples_dir, os.path.basename(en_csv_path))
        copyfile(en_csv_path, dst_file)

# Ground truth for the test set: the transcript with punctuation emitted as
# separate, reduced tokens, one .txt file per evaluation segment.
os.makedirs(punkProse_groundtruth_samples_dir, exist_ok=True)
for _, _, _, en_csv_path in heroes_data_eval:
    dst_file = os.path.join(punkProse_groundtruth_samples_dir, os.path.splitext(os.path.basename(en_csv_path))[0] + '.txt')
    input_proscript = Proscript()
    input_proscript.from_file(en_csv_path)

    with open(dst_file, 'w') as f:
        f.write(get_proscript_transcript(input_proscript, punctuation_as_token=True, reduce_punc=True))

In [None]:
#PREPARE TEXT DATASET FROM punctuation recovered HEROES SEGMENTS, PREPARE PATH FILES FOR SETS

def to_punkprosed(path):
    """Map a proscript path to its '<stem>_punkProsed.csv' file in the PunkProse output directory.

    Only the file name of `path` is used: its directory and extension are dropped.
    """
    stem, _extension = os.path.splitext(os.path.basename(path))
    return '/Users/alp/phdCloud/playground/punkHeroes/out_heroes_sugardub/csv/' + stem + '_punkProsed.csv'

# Evaluation set with the EN csv path of every entry swapped for its
# punctuation-recovered ("punkProsed") counterpart; the other three paths are kept.
heroes_data_eval_punkProsed = [(a,b,c,to_punkprosed(d)) for a,b,c,d in heroes_data_eval]  
# English text is rendered from the punkProsed csv, Spanish text is read from
# the original txt files.
data2textfiles(heroes_data_eval_punkProsed, 'onmt-heroes/heroes_eval_punkProsed.en.txt', 
               'onmt-heroes/heroes_eval_punkProsed.es.txt', 
               get_en_transcript_from='csv', 
              get_es_transcript_from='txt')



In [5]:
#Functions for unit counting
#Syllable counter. (Not very accurate for Spanish)
def no_syl_en(word):
    """Estimate the number of syllables in an English word.

    Heuristic: count maximal runs of vowels, drop a final 'e' (silent e), add
    one back for a final 'le', and report at least one syllable for any
    non-empty word.

    Args:
        word: the word; case is ignored and leading/trailing '.:;?!' are removed.

    Returns:
        The estimated syllable count; 0 if nothing is left after stripping.
    """
    vowels = 'aeiouyáóúíéüö'
    word = word.lower().strip(".:;?!")
    if not word:
        # Empty or punctuation-only input: `word[0]` used to raise IndexError.
        return 0

    count = 0
    if word[0] in vowels:
        count += 1
    for index in range(1, len(word)):
        # A vowel directly after a consonant starts a new vowel run.
        if word[index] in vowels and word[index - 1] not in vowels:
            count += 1
    if word.endswith('e'):
        count -= 1
    if word.endswith('le'):
        count += 1
    if count == 0:
        count += 1
    return count

def no_syl_es(word):
    """Estimate the number of syllables in a Spanish word.

    Heuristic: every vowel counts as one syllable unless it is the second
    letter of one of the listed diphthongs; any non-empty word has at least
    one syllable.

    Args:
        word: the word; case is ignored and leading/trailing '.:;?!' are removed.

    Returns:
        The estimated syllable count; 0 if nothing is left after stripping.
    """
    vowels = 'aeiouüö' + 'áóúíé'
    diphthongs = ['ai', 'au', 'eu', 'ei', 'oi', 'ou', 'ia', 'ie', 'io', 'iu', 'ua', 'ue', 'ui', 'uo', 'ió', 'ié']
    word = word.lower().strip(".:;?!")
    if not word:
        # Empty or punctuation-only input: `word[0]` used to raise IndexError.
        return 0

    count = 1 if word[0] in vowels else 0
    for index in range(1, len(word)):
        # word[index-1:index+1] is the letter pair that ends at this vowel.
        if word[index] in vowels and word[index - 1:index + 1] not in diphthongs:
            count += 1
    return max(count, 1)

#Count number of words, syllables and characters in each sample
def get_unit_counts(dataset):
    """Count words, syllables and characters of every segment, per language.

    `dataset` yields (es_txt_path, es_csv_path, en_txt_path, en_csv_path);
    only the two csv (proscript) paths are read.

    Returns:
        (counts_en, counts_es): two lists with one (words, syllables, characters)
        tuple per segment, in dataset order.
    """
    def load_words(csv_path):
        proscript = Proscript()
        proscript.from_file(csv_path)
        return [entry.word for entry in proscript.word_list]

    heroes_unit_counts_en = []  # (Word, syllables, character)
    heroes_unit_counts_es = []  # (Word, syllables, character)

    for es_txt_path, es_csv_path, en_txt_path, en_csv_path in dataset:
        words_en = load_words(en_csv_path)
        words_es = load_words(es_csv_path)

        syllables_en = sum(no_syl_en(word) for word in words_en)
        syllables_es = sum(no_syl_es(word) for word in words_es)

        # Characters are counted inside words only (no spaces or punctuation).
        chars_en = sum(len(word) for word in words_en)
        chars_es = sum(len(word) for word in words_es)

        heroes_unit_counts_en.append((len(words_en), syllables_en, chars_en))
        heroes_unit_counts_es.append((len(words_es), syllables_es, chars_es))

    return heroes_unit_counts_en, heroes_unit_counts_es

In [9]:
#ANALYZE UNIT COUNTS IN SETS
# Per-segment (words, syllables, characters) counts for every loaded segment.
# Only the full list is analysed; the per-split variants below are kept
# commented out (they need the split cell to have been run and take long).
counts_heroes_data_all_en, counts_heroes_data_all_es = get_unit_counts(heroes_data_all)
# counts_heroes_data_eval_en, counts_heroes_data_eval_es = get_unit_counts(heroes_data_eval)
# counts_heroes_data_dev_en, counts_heroes_data_dev_es = get_unit_counts(heroes_data_dev)
# counts_heroes_data_dev_train_en, counts_heroes_data_dev_train_es = get_unit_counts(heroes_data_dev_train)
# counts_heroes_data_dev_val_en, counts_heroes_data_dev_val_es = get_unit_counts(heroes_data_dev_val)

# Average English word count per segment (index 0 of each count tuple).
print(np.mean([c[0] for c in counts_heroes_data_all_en]))
# print(np.mean([c[0] for c in counts_heroes_data_eval_en]))
# print(np.mean([c[0] for c in counts_heroes_data_dev_en]))
# print(np.mean([c[0] for c in counts_heroes_data_dev_train_en]))
# print(np.mean([c[0] for c in counts_heroes_data_dev_val_en]))

# Average English syllable count per segment (index 1 of each count tuple).
print(np.mean([c[1] for c in counts_heroes_data_all_en]))
# print(np.mean([c[1] for c in counts_heroes_data_eval_en]))
# print(np.mean([c[1] for c in counts_heroes_data_dev_en]))
# print(np.mean([c[1] for c in counts_heroes_data_dev_train_en]))
# print(np.mean([c[1] for c in counts_heroes_data_dev_val_en]))

#plot histogram of segment lengths w.r.t word count
# to_plot = [counts_heroes_data_all_en, counts_heroes_data_eval_en, counts_heroes_data_dev_en, counts_heroes_data_dev_train_en, counts_heroes_data_dev_val_en]
# plot_no = 1
# unit_id = 0 #word
# plt.figure(figsize=(8,20))
# for counts in to_plot:
#     plt.subplot(len(to_plot), 1, plot_no)
#     x = [count[unit_id] for count in counts]
#     # the histogram of the data
#     plt.hist(x,density=1, bins=50) 
#     plt.axis([0, 40, 0, 0.2]) 
#     plt.xlabel('# Unit')
#     plt.ylabel('Probability')
#     plot_no += 1

7.9071633237822345
10.249856733524355


In [10]:
#Get unit count averages through samples
def _unit_values(counts, unit_index):
    """Pull one unit (0=word, 1=syllable, 2=character) out of every per-segment count tuple."""
    return [segment_counts[unit_index] for segment_counts in counts]

syllables_per_segment_en = _unit_values(counts_heroes_data_all_en, 1)
syllables_per_segment_es = _unit_values(counts_heroes_data_all_es, 1)

print("Word")
print(np.mean(_unit_values(counts_heroes_data_all_en, 0)))
print(np.mean(_unit_values(counts_heroes_data_all_es, 0)))

print("\nSyllable")
syl_mean_en = np.mean(syllables_per_segment_en)
syl_mean_es = np.mean(syllables_per_segment_es)
print(syl_mean_en)
print(syl_mean_es)
# Per-segment ES/EN syllable ratio; its mean and spread are printed below.
syllable_ratios = np.array(syllables_per_segment_es) / np.array(syllables_per_segment_en)
print('mean ratio', np.mean(syllable_ratios))
print('std', np.std(syllable_ratios))

print("\nChar")
print(np.mean(_unit_values(counts_heroes_data_all_en, 2)))
print(np.mean(_unit_values(counts_heroes_data_all_es, 2)))

Word
7.9071633237822345
6.905730659025788

Syllable
10.249856733524355
12.80974212034384
mean ratio 1.310834443649147
std 0.3972435107732328

Char
31.702865329512893
29.389684813753583


In [11]:
#LOAD SENTENCE PIECING MODEL 
# `sp` is used as a module-level global by the parsing cell (decode_pieces)
# and by token_to_piece_labels (EncodeAsPieces).
sp = spm.SentencePieceProcessor()
sp.Load("/Users/alp/Documents/Corpora/ted_en-es/sentencepieced_tedheroes/tedheroes.model")

#LOAD TRANSLATION OUTPUT
# Translation log with attention weights, read by the parsing cell below.
# The commented-out alternative is the run on the punctuation-recovered
# ("punkProsed") evaluation input.
#translation_data_file = '/Users/alp/phdCloud/playground/dubbing-w-pause/onmt-heroes/test/ted-opt-heroes-noreset_step_74970.pt_heroes_eval_punkProsed.translation.out'
translation_data_file = '/Users/alp/phdCloud/playground/dubbing-w-pause/onmt-heroes/test/ted-opt-heroes-noreset_step_74970.pt_heroes.translation.out'

In [12]:
#PARSE TRANSLATION INFORMATION [of all sentences] WITH ATTENTION WEIGHTS
#Also does a sanity check as some samples from the dataset have minor problems
#
# Per sentence the log is read as: a "SENT ...: [pieces]" line, a "PRED ...: pieces"
# line, "PRED SCORE"/"GOLD" lines (ignored), one attention header line (skipped),
# then one attention row per output piece: the piece followed by one weight per
# input piece, the selected weight prefixed with '*'. A row starting with "</s>"
# ends the sentence.
# The i-th sentence of the log is paired with the i-th path in
# all_available_proscript_path; only sentences whose proscript transcript equals
# the decoded input pieces are kept.
all_input_pieces = []
all_pred_pieces = []
all_attention_weights = []
all_attention_binary = []
all_proscript = []
all_proscript_path = []

with open(translation_data_file, 'r') as f:
    sent_index = 0
    attention_first_line = True
    for l in f:
        if l.isspace():
            continue
        elif l.startswith(("GOLD", "PRED SCORE")):
            continue
        elif l.startswith("SENT"):
            input_pieces = ast.literal_eval(l[l.find(':') + 2:])
            continue
        elif l.startswith("PRED"):
            pred_pieces = l[l.find(':') + 2:].split()
            continue

        if attention_first_line:
            # Header row of the attention table: start fresh, empty matrices.
            attention_first_line = False
            num_input_pieces = len(input_pieces)
            num_pred_pieces = len(pred_pieces)
            sent_attention_weights = np.empty([0, num_input_pieces])
            sent_attention_binary = np.empty([0, num_input_pieces])
            continue

        #take the rest to the attention matrices until EOS is seen
        row = l.split()
        try:
            weights = [float(w[1:]) if w.startswith('*') else float(w) for w in row[1:]]
        except ValueError:
            # float() is the only call that can fail here; anything else
            # (e.g. KeyboardInterrupt) must not be swallowed.
            # An unparsable row stops the whole parse; the line is shown for inspection.
            print(l)
            break
        binary = [1 if w.startswith('*') else 0 for w in row[1:]]

        if not row[0] == "</s>":
            sent_attention_weights = np.vstack([sent_attention_weights, weights])
            sent_attention_binary = np.vstack([sent_attention_binary, binary])
        else:
            #read corresponding proscript
            input_proscript = Proscript()
            input_proscript.from_file(all_available_proscript_path[sent_index])

            #make sure it matches with proscript info
            proscript_transcript = get_proscript_transcript(input_proscript)
            subtitle_transcript = sp.decode_pieces(input_pieces)

            if proscript_transcript == subtitle_transcript:
                #store information 
                all_proscript_path.append(all_available_proscript_path[sent_index])
                all_proscript.append(input_proscript)
                all_input_pieces.append(input_pieces)
                all_pred_pieces.append(pred_pieces)
                all_attention_weights.append(sent_attention_weights)
                all_attention_binary.append(sent_attention_binary)
            else:
                print("Problem at %s"%all_available_proscript_path[sent_index])
                print(subtitle_transcript)
                print(proscript_transcript)
                print("-----")
            attention_first_line = True
            sent_index += 1

In [13]:
#EXTRACT PHRASE STRUCTURE FROM INPUT SENT PROSCRIPT

@dataclass
#struct to store phrase information
class Phrase:
    """A run of consecutive words treated as one phrase, with its timing.

    The notes on the fields describe how get_input_structure fills them.
    NOTE(review): time values are presumably in seconds — confirm against the
    proscript format.
    """
    label: int          # index of the phrase within its utterance, starting at 0
    tokens: list        # the words of the phrase, without punctuation
    transcript: str     # the phrase as text, words joined by single spaces
    start_time: float   # start time of the first word
    end_time: float     # end time of the last word
    pause_after: float  # pause following the last word
#Sample phrase initialize (every field is required, none has a default):
#p1 = Phrase(label=0, tokens=['a', 'b'], transcript="a b", start_time=0.0, end_time=1.0, pause_after=0.0)
        
#Get prosodic structure from proscript representation
def get_input_structure(input_proscript, min_segment_pause_interval=0.0):
    """Split the words of a proscript into phrases at pauses.

    A phrase is closed after every word whose `pause_after` exceeds
    `min_segment_pause_interval`, and after the last word.

    Returns:
        (input_tokens, input_token_labels, phrase_structure):
        input_tokens — per word, punctuation_before + word + punctuation_after
            (whitespace-delimited text rather than a true token);
        input_token_labels — per word, the id of the phrase it belongs to;
        phrase_structure — one Phrase per phrase, in order.
    """
    words = input_proscript.word_list
    final_position = len(words) - 1

    input_tokens = []
    input_token_labels = []
    phrase_structure = []

    phrase_id = 0
    phrase_words = []    # bare words of the phrase being built
    phrase_texts = []    # the same words with their punctuation attached
    phrase_start_time = None

    for position, w in enumerate(words):
        if not phrase_words:
            #first token in phrase
            phrase_start_time = w.start_time

        text = w.punctuation_before + w.word + w.punctuation_after
        input_tokens.append(text)
        input_token_labels.append(phrase_id)
        phrase_texts.append(text)
        phrase_words.append(w.word)

        closes_phrase = w.pause_after > min_segment_pause_interval or position == final_position
        if not closes_phrase:
            continue

        phrase_structure.append(Phrase(label=phrase_id,
                                       tokens=phrase_words,
                                       transcript=' '.join(phrase_texts).strip(),
                                       start_time=phrase_start_time,
                                       end_time=w.end_time,
                                       pause_after=w.pause_after))

        #start collecting the next phrase
        phrase_id += 1
        phrase_words = []
        phrase_texts = []

    return input_tokens, input_token_labels, phrase_structure
    
#labels piece sequence with phrase labels w.r.t input token labels
#REFACTOR: this function can be merged to get_input_structure()
def token_to_piece_labels(input_pieces, input_tokens, input_token_labels):
    """Give every input piece the phrase label of the token it came from.

    Each token is re-encoded with the global SentencePiece processor `sp`, and
    its pieces are matched one by one against `input_pieces`.

    Args:
        input_pieces: pieces of the whole input sentence.
        input_tokens: whitespace-delimited tokens of the same sentence.
        input_token_labels: phrase label of every token.

    Returns:
        A list with one label per entry of `input_pieces`.

    Raises:
        IndexError: if the tokens run out before every input piece is matched
            (including when `input_tokens` is empty).
    """
    #label pieced tokens w.r.t proscript phrase labels
    input_piece_labels = []
    proscript_token_index = 0
    piece_index = 0
    curr_token_pieces = sp.EncodeAsPieces(input_tokens[proscript_token_index])
    while piece_index < len(input_pieces):
        if not curr_token_pieces:
            # Current token fully consumed: move on to the next one. This was a
            # bare `except:` around pop(), which hid every other error as well.
            proscript_token_index += 1
            curr_token_pieces = sp.EncodeAsPieces(input_tokens[proscript_token_index])
            continue
        token_piece = curr_token_pieces.pop(0)

        if input_pieces[piece_index] == token_piece:
            input_piece_labels.append(input_token_labels[proscript_token_index])
            piece_index += 1
        else:
            # Best effort: the unmatched token piece is dropped and the same
            # input piece is tried against the next token piece.
            print("somethings not right in token alignment")

    return input_piece_labels

In [23]:
#ALIGNING INPUT LABELS TO PREDICTION PIECES
#Naive approach using the binary attention matrix (may be unreliable)
def map_piece_labels_naive(input_piece_labels, pred_pieces, sent_attention_binary):
    """Label every predicted piece with the label of the input piece it attends to.

    Args:
        input_piece_labels: phrase label of every input piece.
        pred_pieces: the predicted pieces; only their number is used.
        sent_attention_binary: one row per predicted piece, one column per
            input piece; the arg-max of each row picks the input piece
            (the first column on ties).

    Returns:
        A list with one label per predicted piece.
    """
    # The original also looked the matched piece up in a global `input_pieces`
    # that is not a parameter; the value was unused, so the lookup is dropped.
    return [input_piece_labels[np.argmax(sent_attention_binary[piece_no])]
            for piece_no in range(len(pred_pieces))]

#extends label sequences step by step (iterative; result order matches the former recursive version)
def extend_possibilities(poss_label_seqs, labels, extend_by):
    """Extend every label sequence by `extend_by` labels in all allowed ways.

    A sequence may only be continued with its current last label or with that
    label + 1, so the results are non-decreasing runs over consecutive labels.

    Args:
        poss_label_seqs: list of non-empty label sequences (lists) to extend.
        labels: the labels that may be appended.
        extend_by: number of labels to append to each sequence.

    Returns:
        A new list with every possible extended sequence. For `extend_by <= 0`
        the input sequences are returned unchanged (in a new list); the
        recursive version never terminated in that case.
    """
    current = list(poss_label_seqs)
    for _ in range(max(extend_by, 0)):
        current = [label_seq + [label]
                   for label_seq in current
                   for label in labels
                   if label == label_seq[-1] or label == label_seq[-1] + 1]
    return current

#check if pieces of same word have conflicting labels
def check_splitlabels(piece_seq, piece_label_seq):
    """Return True if no word has pieces with different labels.

    A piece that does not start with '▁' continues the word of the piece
    before it and must therefore carry the same label as that piece.
    """
    for i, piece in enumerate(piece_seq):
        # The first piece has no predecessor; comparing it with index -1 used
        # to wrap around to the LAST label.
        if i > 0 and not piece.startswith('▁') and piece_label_seq[i - 1] != piece_label_seq[i]:
            return False
    return True

def label_mask(label_seq, hot_label):
    """Return a 0/1 list marking the positions of `label_seq` that equal `hot_label`."""
    return [1 if label == hot_label else 0 for label in label_seq]
    
def get_pred_labels(input_pieces, input_piece_labels, pred_pieces, sent_attention_weights):
    """Pick the most likely phrase-label sequence for the predicted pieces.

    Every label sequence that starts at 0, never decreases and grows by at
    most one per piece is enumerated. Sequences that never reach the highest
    input label, or that split a word across two labels, are discarded. Each
    remaining sequence is scored by the product, over predicted pieces, of the
    attention mass the piece puts on the input pieces carrying its label.

    Args:
        input_pieces: input pieces (unused; kept for interface compatibility).
        input_piece_labels: phrase label of every input piece.
        pred_pieces: predicted pieces.
        sent_attention_weights: one row per predicted piece, one column per
            input piece.

    Returns:
        (best_sequence, candidate_sequences, candidate_scores). If no candidate
        survives filtering, best_sequence is all zeros and both lists are
        empty. For an empty prediction all three are empty lists.

    Note:
        The number of enumerated sequences grows quickly with the number of
        predicted pieces and labels.
    """
    num_pred_pieces = len(pred_pieces)
    if num_pred_pieces == 0:
        return [], [], []

    labels = np.unique(input_piece_labels)

    all_possible_pred_label_seqs = [[0]]
    if num_pred_pieces > 1:
        # A single predicted piece needs no extension (and extending by 0 must
        # not be requested from the enumerator).
        all_possible_pred_label_seqs = extend_possibilities(all_possible_pred_label_seqs, labels, num_pred_pieces - 1)

    #filter out seqs where input seqs are not fully represented and if there's any multilabeled words
    possible_pred_label_seqs = [seq for seq in all_possible_pred_label_seqs
                                if labels[-1] in seq and check_splitlabels(pred_pieces, seq)]

    if not possible_pred_label_seqs:
        # Explicit fallback; this was a bare `except:` around np.argmax([]).
        return [0] * num_pred_pieces, possible_pred_label_seqs, []

    # The attention mass of (piece, label) is the same in every candidate that
    # uses that pair, so it is computed once and reused.
    mass_cache = {}

    def label_mass(pred_piece_index, label):
        key = (pred_piece_index, label)
        if key not in mass_cache:
            masked_weights = sent_attention_weights[pred_piece_index] * label_mask(input_piece_labels, label)
            mass_cache[key] = np.sum(masked_weights)
        return mass_cache[key]

    #get scores of each possible sequence
    possible_seq_scores = []
    for poss_label_seq in possible_pred_label_seqs:
        score = 1
        for pred_piece_index, label in enumerate(poss_label_seq):
            score *= label_mass(pred_piece_index, label)
        possible_seq_scores.append(score)

    best_possible_sequence = possible_pred_label_seqs[np.argmax(possible_seq_scores)]

    return best_possible_sequence, possible_pred_label_seqs, possible_seq_scores


In [14]:
#TRANSLATE PREDICTION PIECE LABELS TO PREDICTION TOKEN LABELS
def most_common(lst):
    """Return the most frequent element of `lst`; on ties, the one that occurs first.

    Raises ValueError for an empty list.
    """
    data = Counter(lst)
    return max(lst, key=data.get)

def piece_to_token(pred_pieces, pred_piece_labels):
    """Merge pieces into tokens and give each token a single label.

    A piece starting with '▁' opens a new token (the marker is removed); any
    other piece is appended to the current token. A token's label is the most
    common label among its pieces.

    Args:
        pred_pieces: predicted pieces.
        pred_piece_labels: one label per piece.

    Returns:
        (tokens, token_labels), two lists of equal length.
    """
    tokens = []
    token_label_votes = []
    for piece, piece_label in zip(pred_pieces, pred_piece_labels):
        if piece.startswith('▁'):
            tokens.append(piece[1:])
            token_label_votes.append([piece_label])
        elif not tokens:
            # Nothing to continue yet: a prediction may start with a piece that
            # lacks the word marker. Treat it as a token of its own instead of
            # failing on tokens[-1].
            tokens.append(piece)
            token_label_votes.append([piece_label])
        else:
            tokens[-1] += piece
            token_label_votes[-1].append(piece_label)

    token_labels = [most_common(token_votes) for token_votes in token_label_votes]

    return tokens, token_labels

In [59]:
#EXTRACT OUTPUT PHRASE STRUCTURE W.R.T INPUT PHRASE STRUCTURE
def get_output_phrase_structure(output_tokens, output_token_labels, input_structure):
    """Group output tokens into phrases that mirror the input phrases.

    Args:
        output_tokens: predicted tokens (may carry punctuation).
        output_token_labels: phrase label of every output token.
        input_structure: list of input Phrase objects; output phrase k takes
            its start_time, end_time and pause_after from input phrase k.

    Returns:
        (output_structure, no_match). `no_match` is True when the output labels
        do not reach every input phrase; the whole output then becomes ONE
        phrase that spans the input from first start to last end, with
        pause_after 0.0.

    Tokens are stored without surrounding punctuation; tokens that are
    punctuation only are dropped.
    """
    output_structure = []

    #Take care of the case that output phrases cannot match all input phrases
    # (an empty output counts as "no label reached")
    if len(input_structure) > max(output_token_labels, default=-1) + 1:
        #make only one phrase out of ALL output tokens. The original iterated
        #the still-empty phrase buffer here, so this phrase never had any tokens.
        stripped_tokens = (strip_punctuation(tok) for tok in output_tokens)
        phrase_tokens = [t for t in stripped_tokens if t]
        output_structure.append(Phrase(label=0,
                                       tokens=phrase_tokens,
                                       transcript=' '.join(phrase_tokens),
                                       start_time=input_structure[0].start_time,
                                       end_time=input_structure[-1].end_time,
                                       pause_after=0.0))
        return output_structure, True

    def build_phrase(phrase_id, phrase_tokens):
        # Timing is copied from the input phrase with the same index.
        return Phrase(label=phrase_id,
                      tokens=phrase_tokens,
                      transcript=' '.join(phrase_tokens),
                      start_time=input_structure[phrase_id].start_time,
                      end_time=input_structure[phrase_id].end_time,
                      pause_after=input_structure[phrase_id].pause_after)

    curr_phrase_id = 0
    curr_phrase_tokens = []
    for tok, l in zip(output_tokens, output_token_labels):
        t = strip_punctuation(tok)
        if not l == curr_phrase_id:
            #close phrase
            output_structure.append(build_phrase(curr_phrase_id, curr_phrase_tokens))
            curr_phrase_id += 1
            # Always start from a fresh list: when the first token of the new
            # phrase was punctuation only, the old list used to be kept, so the
            # previous phrase's words leaked into (and were shared with) the next one.
            curr_phrase_tokens = []
        if t:
            curr_phrase_tokens.append(t)

    #get the final phrase
    output_structure.append(build_phrase(curr_phrase_id, curr_phrase_tokens))

    return output_structure, False

def strictly_increasing(L):
    """Return True if no element of `L` is smaller than the one before it.

    Despite the name, equal neighbours are accepted: the test is
    "non-decreasing". Empty and single-element sequences yield True.
    """
    for earlier, later in zip(L, L[1:]):
        if not earlier <= later:
            return False
    return True

#strips punctuation from beginning and end of token
def strip_punctuation(token_string):
    """Return `token_string` without its leading and trailing non-word characters.

    Everything from the first to the last word character (regex \\w) is kept,
    including punctuation in between (e.g. the apostrophe in "don't"). A string
    without any word character yields ''.
    """
    # Greedy `.*` runs to the end and backtracks to the LAST word character;
    # DOTALL lets it cross newlines.
    core = re.search(r"\w(?:.*\w)?", token_string, flags=re.DOTALL)
    if core is None:
        return ''
    return core.group(0)

In [16]:
#Show attention matrix
#Sample call: show_attention(input_pieces, pred_pieces, sent_attention_weights)
# Colours used to mark phrase labels on the tick labels; labels beyond the
# palette length wrap around (label % len(color_palette)).
color_palette = ['blue', 'green', 'red', 'brown', 'magenta', 'black', 'cyan']

def show_attention(input_words, output_words, attentions, input_word_labels = None, pred_word_labels = None):
    """Plot an attention matrix with input words on the x axis and output words on the y axis.

    Args:
        input_words: list of x-axis labels (must be a list: it is concatenated
            with [''] below).
        output_words: list of y-axis labels (same requirement).
        attentions: the matrix handed to matshow.
        input_word_labels: optional phrase label per input word; when given
            (and non-empty) the x tick labels are coloured by label.
        pred_word_labels: optional phrase label per output word; same for the
            y tick labels.

    Shows the figure and closes it; returns nothing.
    """
    # Set up figure with colorbar
    fig = plt.figure(figsize=(12, 12))
    
    ax = fig.add_subplot(111)
    cax = ax.matshow(attentions, cmap='bone_r')
    fig.colorbar(cax)

    # Set up axes
    # The leading '' is a placeholder label for the first tick.
    # NOTE(review): tick labels are set before tick positions are fixed, so the
    # word/tick pairing relies on where MultipleLocator(1) puts its ticks —
    # verify with the installed matplotlib version.
    ax.set_xticklabels([''] + input_words, rotation=90, fontsize=15)
    ax.set_yticklabels([''] + output_words, fontsize=15)

    # Show label at every tick
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
    
    #return plt.gca().get_xticklabels(), plt.gca().get_yticklabels()
    
    #Show labels of words as colors
    # [1:-1] skips the first and the last tick label so that index i is paired
    # with the i-th word label.
    if input_word_labels:
        for i, word_tick in enumerate(plt.gca().get_xticklabels()[1:-1]):
            word_tick.set_color(color_palette[input_word_labels[i] % len(color_palette)]) 
    
    if pred_word_labels:
        for i, word_tick in enumerate(plt.gca().get_yticklabels()[1:-1]):
            word_tick.set_color(color_palette[pred_word_labels[i] % len(color_palette)]) 
        
    plt.show()
    plt.close()

#sample call: show_attention(input_pieces, pred_pieces, sent_attention_weights, input_piece_labels, pred_piece_labels)

In [43]:
#CALCULATE AVERAGE NUMBER OF SYLLABLES SIMILARITY PENALTY

# Reference mean and standard deviation of the ES/EN syllable ratio.
# NOTE(review): these look like rounded versions of the 'mean ratio' / 'std'
# statistics printed by the unit-count cell (1.31 / 0.40 in the saved output) —
# confirm which dataset they were derived from.
syl_std = 0.38
syl_mean = 1.30

def get_syllable_similarity_penalty(utt_en, utt_es):
    """Score how far an utterance's ES/EN syllable ratio is from the reference.

    Args:
        utt_en: list of English phrases (objects with a `tokens` list).
        utt_es: list of Spanish phrases.

    Returns:
        abs((ratio - syl_mean) / syl_std), where `ratio` is
        - the mean of the per-phrase ES/EN syllable ratios when both
          utterances have the same number of phrases, or
        - the ES/EN syllable ratio over all tokens of the utterances otherwise.

    Raises:
        ZeroDivisionError: if an English phrase (or, in the fallback, the whole
            English utterance) has no syllables, or if `utt_en` is empty.
    """
    def syllable_ratio(tokens_en, tokens_es):
        syl_en = sum(no_syl_en(word) for word in tokens_en)
        syl_es = sum(no_syl_es(word) for word in tokens_es)
        return syl_es / syl_en

    if len(utt_en) == len(utt_es):
        phrase_ratios = [syllable_ratio(phrase_en.tokens, phrase_es.tokens)
                         for phrase_en, phrase_es in zip(utt_en, utt_es)]
        avg_factor = sum(phrase_ratios) / len(utt_en)
    else:
        # Phrases cannot be paired: compare the utterances as a whole. This is a
        # single ratio and is used as is; it used to be divided by the number of
        # English phrases as well, which shrank it and inflated the penalty.
        en_tokens = [word for phrase in utt_en for word in phrase.tokens]
        es_tokens = [word for phrase in utt_es for word in phrase.tokens]
        avg_factor = syllable_ratio(en_tokens, es_tokens)

    normalized = (avg_factor - syl_mean) / syl_std

    return abs(normalized)

In [60]:
#BATCH PROCESS INPUT SEGMENTS->OUTPUT SEGMENTS
# Per-segment results. Exactly one entry is appended per processed segment, so
# position i lines up with all_input_pieces[i]; segments that fail hold None.
all_input_tokens = []
all_input_token_labels = []
all_input_piece_labels = []
all_pred_piece_labels = []
all_pred_tokens = []
all_pred_token_labels = []
all_input_phrase_structure = []
all_pred_phrase_structure = []
all_syllable_similarity_penalty = []

MIN_SEGMENT_PAUSE_INTERVAL = 0.3
OUTPUT_DIR = 'samples/batchpro'

# The output files below are opened for writing; make sure the folder exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)

severe_problematic= 0
count = 0

for segment_no, (input_pieces, pred_pieces, input_proscript, sent_attention_binary, sent_attention_weights, proscript_path) in enumerate(zip(all_input_pieces, all_pred_pieces, all_proscript, all_attention_binary, all_attention_weights, all_proscript_path)):
    try:
        input_tokens, input_token_labels, input_phrase_structure = get_input_structure(input_proscript, MIN_SEGMENT_PAUSE_INTERVAL)
        input_piece_labels = token_to_piece_labels(input_pieces, input_tokens, input_token_labels)
        pred_piece_labels, _, _ = get_pred_labels(input_pieces, input_piece_labels, pred_pieces, sent_attention_weights)
        pred_tokens, pred_token_labels = piece_to_token(pred_pieces, pred_piece_labels)
        pred_phrase_structure, problem_matching = get_output_phrase_structure(pred_tokens, pred_token_labels, input_phrase_structure)

        syllable_similarity_penalty = get_syllable_similarity_penalty(input_phrase_structure, pred_phrase_structure)
        segment_failed = False
    except Exception as err:
        # Catch Exception (not a bare except) so that interrupting the kernel
        # still stops the loop. Reset EVERY per-segment result so that nothing
        # from the previous segment leaks into this segment's slot.
        input_tokens, input_token_labels, input_piece_labels, pred_piece_labels, pred_tokens, pred_token_labels = None, None, None, None, None, None
        input_phrase_structure, pred_phrase_structure = None, None
        syllable_similarity_penalty = None
        problem_matching = False
        segment_failed = True
        print("Somethings wrong processing %i"%segment_no)
        print("  %s: %s"%(type(err).__name__, err))
        severe_problematic += 1
        
    if problem_matching:
        print("Couldn't match all phrases in %i"%segment_no)
        
    all_input_tokens.append(input_tokens)
    all_input_token_labels.append(input_token_labels)
    all_input_piece_labels.append(input_piece_labels)
    all_input_phrase_structure.append(input_phrase_structure)
    
    all_pred_piece_labels.append(pred_piece_labels)
    all_pred_tokens.append(pred_tokens)
    all_pred_token_labels.append(pred_token_labels)
    all_pred_phrase_structure.append(pred_phrase_structure)
    
    all_syllable_similarity_penalty.append(syllable_similarity_penalty)
    
    # Nothing can be written for a failed segment (pred_tokens is None).
    if not segment_failed:
        #output files for synthesis
        #1. plain translation string
        sample_id = ''.join(os.path.splitext(os.path.basename(proscript_path))[0].split("_eng_aligned")).replace('eng', '').replace('heroes_', '')
        translation_output_file = os.path.join(OUTPUT_DIR, sample_id + '.translation.transcript.txt')
        full_translation = ' '.join(pred_tokens)
        with open (translation_output_file, 'w') as f_txt:
            f_txt.write(full_translation)

        #2. phrase structure: one tab-separated line per phrase
        #   (start_time, end_time, pause_after, token list)
        phrase_output_file = os.path.join(OUTPUT_DIR, sample_id + '.translation.structure.txt')
        with open(phrase_output_file, 'w') as f:
            for st in pred_phrase_structure:
                f.write("%f\t%f\t%f\t%s\n"%(st.start_time, st.end_time, st.pause_after, st.tokens))

    count += 1
    print('at', count)
    
print("Processed %i. %i problematic"%(count, severe_problematic))

at 1
at 2
at 3
at 4
at 5
at 6
at 7
at 8
at 9
at 10
at 11
at 12
at 13
at 14
at 15
at 16
at 17
at 18
at 19
at 20
at 21
at 22
at 23
at 24
at 25
at 26
at 27
at 28
at 29
at 30
at 31
at 32
at 33
at 34
at 35
at 36
at 37
at 38
at 39
at 40
at 41
at 42
at 43
at 44
at 45
at 46
at 47
at 48
at 49
at 50
at 51
at 52
at 53
at 54
at 55
at 56
at 57
at 58
at 59
at 60
at 61
at 62
at 63
at 64
at 65
at 66
at 67
at 68
at 69
at 70
at 71
at 72
at 73
at 74
at 75
at 76
at 77
at 78
at 79
at 80
at 81
at 82
at 83
at 84
at 85
at 86
at 87
at 88
at 89
at 90
at 91
at 92
at 93
at 94
at 95
at 96
at 97
at 98
at 99
at 100
at 101
at 102
at 103
at 104
at 105
at 106
at 107
at 108
at 109
at 110
at 111
Couldn't match all phrases in 111
at 112
at 113
at 114
at 115
at 116
at 117
at 118
at 119
at 120
at 121
at 122
at 123
at 124
at 125
at 126
at 127
at 128
at 129
at 130
at 131
at 132
at 133
at 134
Couldn't match all phrases in 134
at 135
at 136
at 137
at 138
at 139
at 140
at 141
at 142
at 143
at 144
at 145
at 146
at 147
at 148
at 1

at 1181
at 1182
at 1183
at 1184
at 1185
at 1186
at 1187
at 1188
at 1189
at 1190
at 1191
at 1192
at 1193
at 1194
at 1195
at 1196
at 1197
at 1198
at 1199
at 1200
at 1201
at 1202
at 1203
at 1204
at 1205
at 1206
at 1207
at 1208
at 1209
at 1210
at 1211
at 1212
at 1213
at 1214
at 1215
at 1216
at 1217
at 1218
at 1219
at 1220
at 1221
at 1222
at 1223
at 1224
at 1225
at 1226
at 1227
at 1228
at 1229
at 1230
at 1231
at 1232
at 1233
at 1234
at 1235
at 1236
at 1237
at 1238
at 1239
at 1240
at 1241
at 1242
Couldn't match all phrases in 1242
at 1243
at 1244
at 1245
at 1246
at 1247
at 1248
at 1249
at 1250
at 1251
at 1252
at 1253
at 1254
at 1255
at 1256
at 1257
at 1258
at 1259
at 1260
at 1261
at 1262
at 1263
at 1264
at 1265
at 1266
at 1267
at 1268
at 1269
at 1270
at 1271
at 1272
at 1273
at 1274
at 1275
at 1276
at 1277
at 1278
at 1279
at 1280
at 1281
at 1282
at 1283
at 1284
at 1285
at 1286
at 1287
at 1288
at 1289
at 1290
at 1291
at 1292
at 1293
at 1294
at 1295
at 1296
at 1297
at 1298
at 1299
at 1300
at 13

at 2201
at 2202
at 2203
at 2204
at 2205
at 2206
at 2207
at 2208
at 2209
at 2210
at 2211
at 2212
at 2213
at 2214
at 2215
at 2216
at 2217
at 2218
at 2219
at 2220
at 2221
at 2222
at 2223
at 2224
at 2225
at 2226
at 2227
at 2228
at 2229
at 2230
at 2231
at 2232
at 2233
at 2234
at 2235
at 2236
at 2237
at 2238
at 2239
at 2240
at 2241
at 2242
at 2243
at 2244
at 2245
at 2246
at 2247
at 2248
at 2249
at 2250
at 2251
at 2252
at 2253
at 2254
at 2255
at 2256
at 2257
at 2258
at 2259
at 2260
at 2261
at 2262
at 2263
at 2264
at 2265
at 2266
at 2267
at 2268
at 2269
at 2270
at 2271
at 2272
at 2273
at 2274
at 2275
at 2276
at 2277
at 2278
at 2279
at 2280
at 2281
at 2282
at 2283
at 2284
at 2285
at 2286
at 2287
at 2288
at 2289
at 2290
at 2291
at 2292
at 2293
at 2294
at 2295
at 2296
at 2297
at 2298
at 2299
at 2300
at 2301
at 2302
at 2303
at 2304
at 2305
at 2306
at 2307
at 2308
at 2309
at 2310
at 2311
at 2312
at 2313
at 2314
at 2315
at 2316
at 2317
at 2318
at 2319
at 2320
at 2321
at 2322
at 2323
at 2324
at 2325


at 3238
at 3239
at 3240
at 3241
at 3242
at 3243
at 3244
at 3245
at 3246
at 3247
at 3248
at 3249
at 3250
at 3251
at 3252
at 3253
at 3254
at 3255
at 3256
at 3257
at 3258
at 3259
at 3260
at 3261
at 3262
at 3263
at 3264
at 3265
at 3266
at 3267
at 3268
at 3269
at 3270
at 3271
at 3272
at 3273
at 3274
at 3275
at 3276
at 3277
at 3278
at 3279
at 3280
at 3281
at 3282
at 3283
at 3284
at 3285
at 3286
at 3287
at 3288
at 3289
at 3290
at 3291
at 3292
at 3293
at 3294
at 3295
at 3296
at 3297
at 3298
at 3299
at 3300
at 3301
at 3302
at 3303
at 3304
Couldn't match all phrases in 3304
at 3305
at 3306
at 3307
at 3308
at 3309
at 3310
at 3311
at 3312
at 3313
at 3314
at 3315
at 3316
at 3317
at 3318
at 3319
at 3320
at 3321
at 3322
at 3323
at 3324
at 3325
at 3326
at 3327
at 3328
at 3329
at 3330
at 3331
at 3332
at 3333
at 3334
at 3335
at 3336
at 3337
at 3338
at 3339
at 3340
at 3341
at 3342
at 3343
at 3344
at 3345
at 3346
at 3347
at 3348
at 3349
at 3350
at 3351
at 3352
at 3353
at 3354
at 3355
at 3356
at 3357
at 33

In [61]:
#STUDY PARTICULAR DATA SAMPLE
# Runs the segment pipeline on one sample, prints the intermediate
# segmentations, and writes the synthesis input files for it.
# Relies on names defined in other cells (all_* lists, the pipeline functions
# and MIN_SEGMENT_PAUSE_INTERVAL).
WORKING_SENTENCE_NO = 3133

input_pieces = all_input_pieces[WORKING_SENTENCE_NO]
pred_pieces = all_pred_pieces[WORKING_SENTENCE_NO]
input_proscript = all_proscript[WORKING_SENTENCE_NO]
sent_attention_binary = all_attention_binary[WORKING_SENTENCE_NO]
sent_attention_weights = all_attention_weights[WORKING_SENTENCE_NO]
proscript_path = all_proscript_path[WORKING_SENTENCE_NO]

print("%i - %s"%(WORKING_SENTENCE_NO, proscript_path))

# 1. read from processed batch
# input_tokens = all_input_tokens[WORKING_SENTENCE_NO]
# input_token_labels = all_input_token_labels[WORKING_SENTENCE_NO]
# input_piece_labels = all_input_piece_labels[WORKING_SENTENCE_NO]
# input_phrase_structure = all_input_phrase_structure[WORKING_SENTENCE_NO]

# pred_tokens = all_pred_tokens[WORKING_SENTENCE_NO]
# pred_piece_labels = all_pred_piece_labels[WORKING_SENTENCE_NO]
# pred_token_labels = all_pred_token_labels[WORKING_SENTENCE_NO]
# pred_phrase_structure = all_pred_phrase_structure[WORKING_SENTENCE_NO]

# 2. process your own
input_tokens, input_token_labels, input_phrase_structure = get_input_structure(input_proscript, min_segment_pause_interval=MIN_SEGMENT_PAUSE_INTERVAL)
input_piece_labels = token_to_piece_labels(input_pieces, input_tokens, input_token_labels)
#pred_piece_labels = map_piece_labels_naive(input_piece_labels, pred_pieces, sent_attention_binary)
pred_piece_labels, possible_pred_label_seqs, possible_seq_attn_scores = get_pred_labels(input_pieces, input_piece_labels, pred_pieces, sent_attention_weights)
pred_tokens, pred_token_labels = piece_to_token(pred_pieces, pred_piece_labels)
pred_phrase_structure, problem_matching = get_output_phrase_structure(pred_tokens, pred_token_labels, input_phrase_structure)

syllable_similarity_penalty = get_syllable_similarity_penalty(input_phrase_structure, pred_phrase_structure)
print('syllable_similarity_penalty', syllable_similarity_penalty)

if problem_matching:
    print("Couldn't match all input phrases")
    
print("=======")

input_segments_from_tokens = [' '.join(p.tokens) for p in input_phrase_structure]

print("Input segments (from proscript tokens)")
print(input_segments_from_tokens)


# Group the predicted pieces by label: bucket k collects the pieces labelled k.
pred_segment_pieces = [[] for i in range(max(pred_piece_labels) + 1)]
for t,i in zip(pred_pieces, pred_piece_labels):
    pred_segment_pieces[i].append(t)

pred_segments_from_pieces = [' '.join(segment_pieces) for segment_pieces in pred_segment_pieces]

print('\nSegmented output pieces')
print(pred_segments_from_pieces)

# Same grouping, at token level.
pred_segment_tokens = [[] for i in range(max(pred_token_labels) + 1)]
for t,i in zip(pred_tokens, pred_token_labels):
    pred_segment_tokens[i].append(t)

pred_segments_from_tokens = [' '.join(segment_tokens) for segment_tokens in pred_segment_tokens]

print("\nSegmented output tokens")
print(pred_segments_from_tokens)

#SHOW matched input and prediction segments
print("\nMatched input and pred segments")
full_input = ' '.join(input_segments_from_tokens)
full_translation = ' '.join(pred_segments_from_tokens)
print(full_input)
print(full_translation)
print()

# zip() stops at the shorter list, so surplus segments on either side are not shown.
for input_seg, pred_seg in zip(input_segments_from_tokens, pred_segments_from_tokens):
    print(">" + input_seg)
    print("<" + pred_seg)
    print()

#output files for synthesis
# Create the target folder first so the writes below cannot fail on a missing directory.
os.makedirs('samples', exist_ok=True)

#1. plain translation string
sample_id = ''.join(os.path.splitext(os.path.basename(proscript_path))[0].split("_eng_aligned")).replace('eng', '').replace('heroes_', '')
translation_output_file = os.path.join('samples', sample_id + '.translation.transcript.txt')
with open (translation_output_file, 'w') as f_txt:
    f_txt.write(full_translation)
    
#2. phrase structure: one tab-separated line per phrase
#   (start_time, end_time, pause_after, token list)
phrase_output_file = os.path.join('samples', sample_id + '.translation.structure.txt')
with open(phrase_output_file, 'w') as f:
    for st in pred_phrase_structure:
        f.write("%f\t%f\t%f\t%s\n"%(st.start_time, st.end_time, st.pause_after, st.tokens))

3133 - /Users/alp/Movies/heroes/corpus_post/heroes_s3_12/spa-eng/segments_eng/heroes_s3_12_eng_aligned_eng0094.csv


KeyboardInterrupt: 

In [63]:
# Display the phrase structure currently bound to input_phrase_structure.
input_phrase_structure

[Phrase(label=0, tokens=['something', 'like', 'that'], transcript='something like that.', start_time=0.0, end_time=0.82, pause_after=0.96),
 Phrase(label=1, tokens=['believe', 'me'], transcript='believe me.', start_time=1.78, end_time=2.53, pause_after=0.97),
 Phrase(label=2, tokens=['this', 'little', 'baby', 'will', 'be', 'in', 'your', 'life', 'for', 'at', 'least', '16', 'more', 'years'], transcript='this little baby will be in your life for at least 16 more years,', start_time=3.5, end_time=7.63, pause_after=0.89),
 Phrase(label=3, tokens=['and', 'many', 'more', 'hopefully', 'after', 'that'], transcript='and many more, hopefully, after that.', start_time=8.52, end_time=10.7, pause_after=0.84),
 Phrase(label=4, tokens=["she's", 'gonna', 'need', 'you', 'to', 'protect', 'her', 'from', 'so', 'many', 'things'], transcript="she's gonna need you to protect her from so many things,", start_time=11.54, end_time=14.82, pause_after=0.84),
 Phrase(label=5, tokens=['terrible'], transcript='terrib

In [52]:
# Display the predicted pieces currently bound to pred_pieces.
# NOTE(review): this cell's execution count is lower than that of the cells
# above it, so the recorded output may belong to a different sample -- re-run to confirm.
pred_pieces

['▁¿', '▁ayuda', '▁a', '▁hacer', '▁qué', '?']

In [54]:
# Display the predicted pieces of the first segment in the batch.
all_pred_pieces[0]

['▁¿',
 'y',
 '▁eso',
 '▁es',
 '▁eso',
 '?',
 '▁soy',
 '▁un',
 '▁agente',
 '▁ahora',
 '?']