In [17]:
import pickle
import numpy as np
import numba
from matplotlib import pyplot as plt
from sklearn.metrics.pairwise import euclidean_distances
from librosa.sequence import dtw
from matplotlib import gridspec
from speechCommon import *
import time
import os

In [18]:
# Root of the working tree; passed as the first argument to readMFCC below.
base_dir = "../ttemp/TamperingDetection"

# Pickled alignment results are written underneath this folder.
hyp_dir = '../ttemp/TamperingDetection/hyp'

# Files listing query ids.
# NOTE(review): these live under data/cfg_files, whereas alignBenchmarkWithParams
# opens <root>/cfg_files/<benchmark>.ids (no "data/") -- confirm which is current.
train_ids = '../ttemp/TamperingDetection/data/cfg_files/train.ids'
all_ids = '../ttemp/TamperingDetection/data/cfg_files/all.ids'

# Forwarded to readMFCC as `editTimeSec`; alignDTW reads this global directly.
# Presumably the duration of the tampering edit in seconds -- TODO confirm.
editTimeSec = 1

In [19]:
# Pairwise cost matrix between a query's MFCCs and its reference's MFCCs.

def dist(query_id, editTimeSec, edit_type, piece):
    """Return the Euclidean cost matrix between a query and its reference.

    Both feature matrices are loaded with ``readMFCC`` from ``base_dir``.
    The query variant is selected by ``edit_type + str(piece)`` together
    with ``editTimeSec``.

    Returns
    -------
    The result of ``euclidean_distances(query, reference)``: entry (i, j)
    compares row i of the query features with row j of the reference
    features.
    """
    reference_feats = readMFCC(base_dir, query_id, piece_type='reference')
    query_variant = edit_type + str(piece)
    query_feats = readMFCC(
        base_dir,
        query_id,
        piece_type='queries',
        edit_type=query_variant,
        editTimeSec=editTimeSec,
    )
    return euclidean_distances(query_feats, reference_feats)

In [20]:
@numba.jit(nopython=True)
def NWTWDP(C, alpha, beta=20, gamma = 1):
    """Fill the two-plane cumulative-cost tensor and its backpointers.

    Parameters
    ----------
    C : 2-D array of pairwise costs. As built by ``dist`` in this notebook,
        rows index the query and columns index the reference.
    alpha : cost added for an "up" step (row i-1 -> i) in the hidden plane,
        and for every step down the first column of the visible plane.
    beta : cost added when the visible plane is entered from the hidden
        plane at the same cell.
    gamma : cost added for a "right" step (column j-1 -> j) in the hidden
        plane.

    Returns
    -------
    (B, D), both of shape (2, C.shape[0], C.shape[1]); plane 0 is the
    visible plane and plane 1 the hidden plane. D holds cumulative costs,
    B holds the backpointer code of the move that produced each D entry:

    * visible plane: 0 = came from hidden(i, j) (same cell),
      1 = diagonal from visible(i-1, j-1), 3 = up (first column only).
    * hidden plane: 0 = came from the visible plane -- diagonally from
      visible(i-1, j-1) for j >= 1, but from visible(i, 0) (same cell)
      in the first column; 2 = right from hidden(i, j-1);
      3 = up from hidden(i-1, j).

    A backtrace has to decode code 0 per plane exactly as listed above.
    """
    # 0: visible, 1: hidden
    # B: 1 Diag, 2 Right, 3 Up, 0 switch plane
    # initialization
    D = np.zeros((2, C.shape[0], C.shape[1]))
    B = np.zeros((2, C.shape[0], C.shape[1]))
    
    # bottom rows
    # Row 0 of the visible plane is just the local cost, so a path may start
    # in any column at no extra charge; row 0 of the hidden plane is
    # unreachable (infinite cost).
    D[0, 0, :] = C[0, :]
    D[1, 0, :] = np.inf
    
    # first cols
    # Column 0 has no left/diagonal predecessor: the visible plane can only
    # step up (code 3, cost alpha) and the hidden plane copies the visible
    # cost of the same cell (code 0, no extra cost).
    for i in range(1, C.shape[0]):
        D[0, i, 0] = D[0, i-1, 0] + alpha
        D[1, i, 0] = D[0, i, 0]
        B[0, i, 0] = 3
        B[1, i, 0] = 0
        
    # rest of the matrix
    for i in range(1, C.shape[0]):
        for j in range(1, C.shape[1]):
        
            # hidden
            # diag visible -> hidden, right in hidden, up in hidden
            # The np.inf placeholder at index 1 keeps argmin aligned with the
            # backpointer codes (0 = switch, 2 = right, 3 = up).
            costs = np.array([D[0, i-1, j-1] + gamma + alpha, np.inf, D[1, i, j-1] + gamma, D[1, i-1, j] + alpha])
            D[1, i, j] = np.min(costs)
            B[1, i, j] = np.argmin(costs)
                
            # visible
            # hidden -> visible, diag
            # Must run after the hidden update above: it reads D[1, i, j]
            # of the current cell.
            costs = np.array([D[1, i, j] + beta, D[0, i-1, j-1] + C[i, j]])
            D[0, i, j] = np.min(costs)
            B[0, i, j] = np.argmin(costs)
            
    return B, D

In [21]:
# Plain Python on purpose (the numba decorator stays disabled).
def backtrace3D(B, D):
    """Follow the backpointers in ``B`` from the last row back towards row 0.

    The walk starts in the visible plane (plane 0) at the last row, in the
    column with the smallest cumulative cost ``D[0, -1]``, and stops as soon
    as row 0 is reached. The row-0 cell itself is not appended.

    Backpointer codes are decoded the way ``NWTWDP`` writes them:

    * 1 -> diagonal step (row-1, col-1), same plane
    * 2 -> step left (col-1), same plane
    * 3 -> step down (row-1), same plane
    * 0 in the visible plane -> the cell was reached from the hidden plane
      at the *same* (row, col), since ``D[0,i,j] = D[1,i,j] + beta``
    * 0 in the hidden plane -> the cell was reached from the visible plane,
      diagonally from (row-1, col-1), since
      ``D[1,i,j] = D[0,i-1,j-1] + gamma + alpha``. The first column is the
      exception: there ``D[1,i,0] = D[0,i,0]``, so the switch is in place.

    The previous version decoded code 0 the other way round (diagonal when
    leaving the visible plane, in place when leaving the hidden plane), so
    it read backpointers of cells that are not on the optimal path.

    Parameters
    ----------
    B, D : arrays of shape (2, rows, cols) as returned by ``NWTWDP``.

    Returns
    -------
    np.ndarray with one ``[plane, row, col]`` entry per visited cell,
    ordered from the end of the alignment towards its start.

    Raises
    ------
    ValueError
        If a backpointer is not one of 0, 1, 2, 3 (the old code would spin
        forever in that case).
    """
    p = 0
    r = D.shape[1] - 1
    c = np.argmin(D[0, -1])
    path_3D = []
    while r > 0:
        path_3D.append([p, r, c])
        step = B[p, r, c]
        if step == 0:
            if p == 0:
                # visible(r, c) <- hidden(r, c): change plane, stay on the cell
                p = 1
            elif c == 0:
                # first-column boundary: hidden(r, 0) <- visible(r, 0)
                p = 0
            else:
                # hidden(r, c) <- visible(r-1, c-1)
                p = 0
                r -= 1
                c -= 1
        elif step == 1:
            r -= 1
            c -= 1
        elif step == 2:
            c -= 1
        elif step == 3:
            r -= 1
        else:
            raise ValueError(
                "unknown backpointer code {} at plane={}, row={}, col={}".format(step, p, r, c)
            )
    return np.asarray(path_3D)

In [38]:
# Aligns a query file with its corresponding reference file and returns the 3-D path through the HSTW tensor
def alignNWTWDP3D(query_id, editTimeSec, edit_type, piece, Ca = 2.4, Cb = 33, gamma = 3):
    """Align one query against its reference with the two-plane DP.

    Parameters
    ----------
    query_id, editTimeSec, edit_type, piece : forwarded to ``dist`` to build
        the pairwise cost matrix.
    Ca : multiplier applied to the median of the per-row minima of the cost
        matrix to obtain ``alpha``.
    Cb : multiplier applied to ``alpha + gamma`` to obtain ``beta``.
    gamma : hidden-plane "right" step cost. It is used in the ``beta``
        formula *and* handed to ``NWTWDP``.

    Returns
    -------
    (path_3D, C) : the backtraced path from ``backtrace3D`` and the cost
        matrix it was computed on.

    Notes
    -----
    Earlier ``gamma`` only entered the ``beta`` formula while ``NWTWDP`` ran
    with its own default (``gamma=1``). ``beta`` was therefore scaled from a
    different hidden-plane entry cost than the DP actually charged, and the
    ``gamma`` value recorded in result folder names was not the one used.
    Results cached under the old behaviour must be regenerated.
    """
    C = dist(query_id, editTimeSec, edit_type, piece)
    alpha = np.median(np.min(C, axis=1)) * Ca
    beta = (alpha + gamma) * Cb
    B, D = NWTWDP(C, alpha, beta=beta, gamma=gamma)
    path_3D = backtrace3D(B, D)
    return path_3D, C

In [23]:
# Aligns a query file with its corresponding reference file using subsequence DTW
# and returns the warping path. Used for debugging.
def alignDTW(query_id, edit_type, piece, weightSet = 'D1'):
    """Return the subsequence-DTW warping path between a query and its reference.

    The query variant is ``edit_type + str(piece)``. Unlike the other
    alignment helpers, the edit duration is NOT a parameter here: the
    module-level ``editTimeSec`` is read instead.

    ``sigma`` and ``dtw_weights`` are not defined in this notebook's visible
    cells -- presumably they come from the ``speechCommon`` star import;
    verify before relying on them.

    Parameters
    ----------
    weightSet : key into ``dtw_weights`` selecting the multiplicative step
        weights passed to ``dtw``.

    Returns
    -------
    The second element of ``dtw``'s return value (the warping path).
    """
    reference_feats = readMFCC(base_dir, query_id, piece_type='reference')
    query_feats = readMFCC(
        base_dir,
        query_id,
        piece_type='queries',
        edit_type=edit_type + str(piece),
        editTimeSec=editTimeSec,
    )
    # dtw gets the transposed matrices; only the path is returned.
    _, warping_path = dtw(
        query_feats.T,
        reference_feats.T,
        subseq=True,
        step_sizes_sigma=sigma,
        weights_mul=dtw_weights[weightSet],
    )
    return warping_path

In [24]:
# Plots the HSTW alignment (left panel) next to the DTW alignment (right panel)

def plotHSTWAlignment(query_id, editTimeSec, edit_type, piece, endLim = 10**6):
    """Show the HSTW and DTW alignments of one query over its cost matrix.

    Left panel: cost matrix with the HSTW path, visible-plane points in blue
    and hidden-plane points in red. Right panel: the same cost matrix with
    the DTW warping path. Previously the DTW path was computed but never
    drawn, leaving the right panel as a bare cost matrix.

    Parameters
    ----------
    query_id, editTimeSec, edit_type, piece : select the query/reference pair.
        Note that ``alignDTW`` takes no ``editTimeSec`` argument and reads the
        module-level ``editTimeSec`` instead, so the two panels only show the
        same query when both values agree.
    endLim : only the first ``endLim`` columns of the cost matrix are drawn
        as the background image.

    The elapsed time for computing both alignments is printed in seconds.
    """
    startTime = time.time()
    path_3D, C = alignNWTWDP3D(query_id, editTimeSec, edit_type, piece, Ca = 2.4, Cb = 33, gamma = 3)
    path_d = np.asarray(alignDTW(query_id, edit_type, piece))
    print(time.time() - startTime)

    # Split the path by plane (column 0) and keep only (row, col).
    path_v = path_3D[path_3D[:, 0] == 0][:, 1:3]
    path_h = path_3D[path_3D[:, 0] == 1][:, 1:3]

    fig, ax = plt.subplots(1, 2, figsize = (20, 8))
    fig.suptitle("{}-{}sec-{}{}".format(query_id, editTimeSec, edit_type, piece))

    for axis, panelTitle in zip(ax, ('HSTW', 'DTW')):
        axis.imshow(C[:, :endLim], aspect = 'auto', origin = 'lower')
        axis.set_title(panelTitle)
        # C is built as distances(query, reference): rows = query, cols = reference.
        axis.set_xlabel('reference index')
        axis.set_ylabel('query index')

    ax[0].scatter(path_v[:, 1], path_v[:, 0], color = 'b', s = 0.5, alpha = 0.3, label = 'visible')
    ax[0].scatter(path_h[:, 1], path_h[:, 0], color = 'r', s = 0.5, alpha = 0.3, label = 'hidden')
    # markerscale enlarges the tiny scatter markers in the legend only.
    ax[0].legend(markerscale = 10)

    # DTW path rows are (query index, reference index), same orientation as C.
    ax[1].scatter(path_d[:, 1], path_d[:, 0], color = 'w', s = 0.5, alpha = 0.3, label = 'DTW')
    ax[1].legend(markerscale = 10)
    plt.show()

In [35]:
# Saves the alignment for a specific query and its reference file.

def alignAndSave(filename, query_id, editTimeSec, edit_type, piece, Ca, Cb, gamma):
    """Align one query and pickle the result to ``filename``.

    Nothing is done for ``query_id == '00'`` or for pieces 11, 12 and 13.
    ``filename`` is printed, and the alignment is skipped if the file
    already exists.

    The pickle holds a dict with:

    * ``'wp'``   -- the path as rows of ``[row, col, plane]`` (the plane
      column is moved to the end), in backtrace order;
    * ``'dist'`` -- the cost-matrix value of each path cell, with a cell that
      appears twice in a row (a plane switch without movement) counted once;
    * ``'size'`` -- the shape of the cost matrix.

    The original loop tried to skip the repeated cell by comparing planes,
    but ``prevPlane`` was assigned once before the loop and never updated,
    so the skip never fired and every repeated cell was counted twice.
    Comparing cell coordinates with the previous element gives the skip the
    plane test was presumably meant to perform. ``'dist'`` can therefore be
    shorter than ``'wp'``.
    """
    if query_id == '00' or piece in [11, 12, 13]:
        return
    print(filename)
    if os.path.exists(filename):
        return

    outDir = os.path.dirname(filename)
    if outDir:
        # os.makedirs('') raises, so only create a folder when there is one.
        os.makedirs(outDir, exist_ok=True)

    path_3D, C = alignNWTWDP3D(query_id, editTimeSec, edit_type, piece, Ca = Ca, Cb = Cb, gamma = gamma)
    # (plane, row, col) -> (row, col, plane)
    path_3D = np.hstack((path_3D[:, 1:3], path_3D[:, 0, None]))

    costs = []
    prevCell = None
    for elem in path_3D:
        cell = (elem[0], elem[1])
        if cell != prevCell:
            costs.append(C[elem[0], elem[1]])
        prevCell = cell

    res = {
        'wp': path_3D,
        'dist': np.asarray(costs),
        'size': C.shape,
    }
    with open(filename, 'wb') as f:
        pickle.dump(res, f)

In [36]:
#Saves all alignments for a specified benchmark

def alignBenchmarkWithParams(benchmark, editTimeSec, editTypes, Ca, Cb, gamma):
    with open ('../ttemp/TamperingDetection/cfg_files/{}.ids'.format(benchmark), 'r') as f:
        for i, query_id in enumerate(f.read().split('\n')):
            if(query_id == ''):
                continue
            print(i, query_id)
            for edit_type in editTypes:
                for piece in range(1, 11):
                    fileFolder = 'HSTW-{}-{}-{}'.format(Ca,Cb, gamma)
                    filename = '{}/{}/{}sec/{}/{}_{}{}.pkl'.format(hyp_dir, benchmark, editTimeSec, fileFolder, query_id, edit_type, piece)
                    if not os.path.exists(filename):
                        print('aligning', query_id, edit_type, piece)
                        alignAndSave(filename, query_id, editTimeSec, edit_type, piece, Ca, Cb, gamma)

In [None]:
# Runs the alignment on every query of the 'test' benchmark.

# Re-set here so this cell does not depend on the value left by earlier cells.
editTimeSec = 1
editTypes = ['i', 'r', 'd', 'n']

# One [Ca, Cb, gamma] triple per hyper-parameter setting to run.
paramsList = [[2.4, 33, 3]]

for params in paramsList:
    alignBenchmarkWithParams('test', editTimeSec, editTypes, *params[:3])