In [None]:
# -*- coding: utf-8 -*-

In [73]:
from __future__ import print_function
from __future__ import division
from prettytable import PrettyTable
import textwrap
# from tabulate import tabulate
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import json
import os
import sys
import string
import subprocess
import cPickle
import editdistance
import nltk
import time
import math
from collections import Counter
import IPython
from IPython import display
from nltk.corpus import stopwords
from collections import defaultdict
from timeit import default_timer as timer
from matplotlib import rcParams
import networkx as nx
import dtw
import matplotlib.patches as patches
import scipy.io.wavfile
from python_speech_features import mfcc
from IPython.display import display
from matplotlib.ticker import MultipleLocator, \
     FormatStrFormatter, AutoMinorLocator
%matplotlib inline

In [3]:
# Tableau 20 colour palette, listed as 8-bit (0-255) RGB triples.
tableau20 = [
    (31, 119, 180), (174, 199, 232), (255, 127, 14), (255, 187, 120),
    (44, 160, 44), (152, 223, 138), (214, 39, 40), (255, 152, 150),
    (148, 103, 189), (197, 176, 213), (140, 86, 75), (196, 156, 148),
    (227, 119, 194), (247, 182, 210), (127, 127, 127), (199, 199, 199),
    (188, 189, 34), (219, 219, 141), (23, 190, 207), (158, 218, 229),
]
# matplotlib expects colour channels as floats in [0, 1]; rescale every triple.
tableau20 = [(red / 255., green / 255., blue / 255.)
             for red, green, blue in tableau20]

In [9]:
def add_tick_lines(ax, words_list, axis_lim, horizontal=True, scale=100.0):
    """Label words along an axis and collect their boundary positions.

    A text label is drawn for every entry of ``words_list`` except the
    silence markers ``'sil'`` and ``'sp'``. The scaled ``end`` of every entry
    (silences included) is returned so the caller can draw boundary lines.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to annotate; only ``ax.text`` is called on it.
    words_list : iterable
        Objects exposing ``start``, ``end`` and ``word`` attributes. ``word``
        may be UTF-8 encoded bytes or an already-decoded text string.
    axis_lim : sequence of two numbers
        Axis limits; only ``axis_lim[1]`` is used to position the labels.
    horizontal : bool
        True: labels run along the x axis, upright. False: labels run along
        the y axis, rotated 90 degrees.
    scale : float
        Divisor applied to ``start``/``end`` to convert them to plot units.

    Returns
    -------
    (ax, tick_lines)
        The same axes, and the list of scaled ``end`` values in input order.
    """
    rot = 0 if horizontal else 90
    # Distance (in plot units) by which labels are pushed off the plot area.
    offset = 4 * scale
    tick_lines = []
    for tup in words_list:
        start_val = tup.start / scale
        end_val = tup.end / scale
        text_val = tup.word
        # Decode only encoded bytes: .decode() on already-decoded text raises
        # on Python 3 and mis-handles non-ASCII unicode on Python 2.
        if isinstance(text_val, bytes):
            text_val = text_val.decode('utf-8')
        if text_val not in ('sil', 'sp'):
            mid_val = (start_val + end_val) / 2
            if horizontal:
                x_text, y_text = mid_val, axis_lim[1] + offset
            else:
                x_text, y_text = -offset, axis_lim[1] - mid_val
            ax.text(x_text, y_text, text_val, size=16, ha="center",
                    va="center", rotation=rot)
        tick_lines.append(end_val)
    return ax, tick_lines
    

In [1]:
def plot_specgram_words(wav_fname, words_list, plot_name):
    """Plot a wav file's spectrogram annotated with word labels and boundaries.

    Parameters
    ----------
    wav_fname : str
        Path to the wav file. If it does not exist, a message is printed and
        nothing is plotted.
    words_list : iterable
        Word entries (``start``, ``end``, ``word`` attributes) forwarded to
        ``add_tick_lines`` with its default scale of 100. Start/end values are
        presumably centiseconds, mapped onto the spectrogram's time axis
        (TODO: confirm against the alignment source).
    plot_name : str or None
        If truthy, the figure is saved as PDF to this path and as PNG to the
        same path with its extension replaced by ``.png``.
    """
    if not os.path.exists(wav_fname):
        print("wav file not found ...")
        # Stop here instead of letting scipy raise on the missing file.
        return

    fig = plt.figure()
    fig.set_size_inches(20, 3)
    ax = fig.add_subplot(111)
    sr, wav_data = scipy.io.wavfile.read(wav_fname)
    # Draw on the explicit axes rather than pyplot's "current" axes.
    ax.specgram(wav_data, Fs=sr, cmap=plt.cm.gist_heat)

    the_y_lim = ax.get_ylim()
    ax.grid(False)
    ax, x_tick_lines = add_tick_lines(ax, words_list, the_y_lim, horizontal=True)
    # Word boundaries become minor ticks so they get their own grid lines.
    ax.set_xticks(x_tick_lines, minor=True)
    ax.xaxis.grid(True, which="minor")
    # Thin out frequency labels: hide every other one, but keep the top one.
    y_labels = ax.get_yticklabels()
    for label in y_labels[::2]:
        label.set_visible(False)
    if y_labels:
        y_labels[-1].set_visible(True)
    if plot_name:
        fig.savefig(plot_name, format='pdf')
        # Swap only the extension; str.replace would also rewrite any "pdf"
        # inside directory or file names.
        fig.savefig(os.path.splitext(plot_name)[0] + ".png", format='png')


In [2]:
def plot_dtw(wav_1, wav_2, es_words_1, es_words_2, plot_name, marker_idx=5):
    """Plot the DTW frame-distance matrix between two utterances.

    MFCC features are extracted from both wav files and aligned with
    ``dtw.dtw`` using an L1 distance between frames. The DTW distance is
    printed, and the returned cost matrix is drawn as a heatmap. Word labels
    and boundaries from ``es_words_2`` go along the x axis; those from
    ``es_words_1`` go along the y axis.

    Parameters
    ----------
    wav_1, wav_2 : str
        Paths to the wav files. ``wav_1`` is the first DTW argument (assumed
        to index heatmap rows); ``wav_2`` is the second (columns).
    es_words_1, es_words_2 : iterable
        Word entries (``start``, ``end``, ``word``) for ``wav_1`` / ``wav_2``.
        Start/end values are used unscaled (scale=1.0), so they are assumed
        to be in MFCC frames (TODO: confirm).
    plot_name : str
        Path of the PDF to write. A PNG is also written, with the extension
        replaced by ``.png``.
    marker_idx : int or None
        Dashed reference lines are drawn at the ``marker_idx``-th word
        boundary counted from the end of each word list; 5 reproduces the
        original figure. A line is skipped when its list has fewer words than
        ``marker_idx``, or when ``marker_idx`` is None/0.
    """
    # Extract MFCC features for both utterances.
    sr1, y1 = scipy.io.wavfile.read(wav_1)
    mfcc1 = mfcc(y1, sr1)
    sr2, y2 = scipy.io.wavfile.read(wav_2)
    mfcc2 = mfcc(y2, sr2)

    # Align with an L1 (Manhattan) distance between MFCC frames.
    dist, cost, _acc, _path = dtw.dtw(
        mfcc1, mfcc2, dist=lambda x, y: np.linalg.norm(x - y, ord=1))
    print("DTW distance: %f" % (dist))

    fig = plt.figure()
    fig.set_size_inches(20, 16)
    ax_0 = fig.add_subplot(111)

    cmap = sns.light_palette((200, 75, 60), input="husl", as_cmap=True)
    ax_0 = sns.heatmap(cost, cbar=False, xticklabels=False, yticklabels=False,
                       ax=ax_0, cmap=cmap)
    ax_0.xaxis.tick_top()
    the_y_lim = ax_0.get_ylim()
    ax_0.grid(False)
    ax_0, x_tick_lines = add_tick_lines(ax_0, es_words_2, the_y_lim,
                                        horizontal=True, scale=1.0)
    ax_0, y_tick_lines = add_tick_lines(ax_0, es_words_1, the_y_lim,
                                        horizontal=False, scale=1.0)

    # Dashed reference lines at a chosen word boundary on each axis. The
    # horizontal line is measured back from the last boundary, presumably to
    # mirror the flipped y placement of the labels (TODO: confirm).
    if marker_idx and len(x_tick_lines) >= marker_idx:
        ax_0.axvline(x_tick_lines[-marker_idx], c="k", linewidth=3,
                     linestyle='--')
    if marker_idx and len(y_tick_lines) >= marker_idx:
        ax_0.axhline(y_tick_lines[-1] - y_tick_lines[-marker_idx], c='k',
                     linewidth=3, linestyle='--')

    fig.savefig(plot_name, format="pdf")
    # Swap only the extension; str.replace would also rewrite any "pdf"
    # inside directory or file names.
    fig.savefig(os.path.splitext(plot_name)[0] + ".png", format='png')

In [4]:
def plot_hist_dtw(df):
    """Overlay DTW-score histograms: all pairs vs. pairs with content words.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``'ZRT'`` column (DTW score per discovered pair) and an
        ``'ES cont match'`` column. Rows with ``'ES cont match' > 0`` form the
        second, overlaid histogram.

    Both histograms share bins of width 0.01, running from the smallest score
    up to ``min(max score, 1.0)``. Scores above that cap fall outside the
    bins. Missing scores are ignored. If there are no scores, a message is
    printed and nothing is plotted.
    """
    binwidth = 0.01
    # NaN scores cannot be binned and would confuse the min/max below.
    dtw_vals = df['ZRT'].dropna()
    good_vals = df['ZRT'][df['ES cont match'] > 0].dropna()
    if dtw_vals.empty:
        print("no DTW scores to plot ...")
        return

    max_hist_val = min(dtw_vals.max(), 1.0)
    # Shared bin edges so the two histograms line up bar for bar.
    bins = np.arange(dtw_vals.min(), max_hist_val + binwidth, binwidth)

    fig = plt.figure()
    fig.set_size_inches(18.5, 10.5)
    ax_dtw_thresh = fig.add_subplot(111)
    ax_dtw_thresh.set_xlabel("DTW score", fontsize=20)
    ax_dtw_thresh.set_title("DTW score histogram", fontsize=30)
    # Major ticks every bin width. Minor ticks must be finer than the majors;
    # a fixed spacing of 10 on a [0, 1] axis would never draw any, so let
    # matplotlib subdivide each major interval instead.
    ax_dtw_thresh.xaxis.set_major_locator(MultipleLocator(binwidth))
    ax_dtw_thresh.xaxis.set_minor_locator(AutoMinorLocator())

    ax_dtw_thresh.hist(dtw_vals, bins=bins, alpha=0.5, color=tableau20[1],
                       label="total pairs discovered")
    ax_dtw_thresh.hist(good_vals, bins=bins, alpha=0.5, color=tableau20[0],
                       label="pairs with content es word(s)")
    ax_dtw_thresh.tick_params(which='both', labelsize=16)
    ax_dtw_thresh.legend(prop={'size': 20})