In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os.path
from pathlib import Path
import pickle
import multiprocessing
import time
import gc
from tqdm import tqdm

In [None]:
# Execute the companion notebook inside this kernel so that its definitions become
# available here; presumably this is where alignDTW (used by dtw()) comes from — TODO confirm.
%run align_tools_cython.ipynb

The Cython extension is already loaded. To reload it, use:
  %reload_ext Cython


In [None]:
# %run _NWTW.ipynb

In [None]:
# Name of the training split; it is interpolated into the query-list path and the output directory.
TRAIN_SET = 'toy'

In [None]:
# Text file listing the pairs to align; run_all_benchmarks reads it line by line and
# expects two entries per line.
QUERY_LIST = Path(f'cfg_files/queries.train.{TRAIN_SET}')

In [None]:
# Alignment systems to run; each name is used as a key into the `steps` and `weights` tables.
SYSTEMS = ['dtw1', 'dtw2', 'subseqdtw']
# Benchmark names; each one is mapped to a pair of feature directories in FEAT_DIRS
# and is also used as a sub-directory name for the results.
BENCHMARKS = ['matching', 'subseq10', 'subseq30', 'partialOverlap', 'pre_5', 'pre_10', 'pre_15', 'pre_20',
             'post_5', 'post_10', 'post_15', 'post_20', 'pre_post_5', 'pre_post_10', 'pre_post_15', 'pre_post_20']

In [None]:
# Root directory holding one feature sub-directory per benchmark variant.
features_root = Path('../ttmp/Chopin_Mazurkas_features')

# benchmark name -> [directory of the first feature file, directory of the second feature file].
# Benchmarks whose name contains 'partial' pair the 'partialStart' and 'partialEnd'
# directories; every other benchmark pairs its own directory with 'original'.
FEAT_DIRS = {}
for benchmark in BENCHMARKS:
    FEAT_DIRS[benchmark] = (
        [features_root / 'partialStart', features_root / 'partialEnd']
        if 'partial' in benchmark
        else [features_root / f'{benchmark}', features_root / 'original']
    )

In [None]:
# Allowed transitions for each system, one two-component step per row.
steps = {
    'dtw1': np.array([[1, 1], [1, 2], [2, 1]]),
    'dtw2': np.array([[1, 0], [0, 1]]),
    'subseqdtw': np.array([[1, 1], [1, 2], [2, 1]]),
}
# One weight per row of the matching `steps` entry (presumably paired by index — confirm in alignDTW).
weights = {
    'dtw1': np.array([2, 3, 3]),
    'dtw2': np.array([1, 1]),
    'subseqdtw': np.array([1, 1, 2]),
}

# Benchmarks

In [None]:
def get_outfile(outdir, benchmark, system, queryid):
    """Return the pickle path for one (benchmark, system, query) result.

    The directory ``outdir/benchmark/system`` is created if it does not exist
    yet (parents included). The returned path is ``queryid`` inside that
    directory with its suffix set to ``.pkl``; the file itself is not created.
    """
    result_dir = outdir.joinpath(benchmark, system)
    result_dir.mkdir(parents=True, exist_ok=True)
    # with_suffix replaces whatever follows the last dot in queryid, if there is one.
    return result_dir.joinpath(queryid).with_suffix('.pkl')

In [None]:
def dtw(dtw_version, F1, F2, outfile):
    """Align two feature matrices with one configured system and pickle the result.

    Args:
        dtw_version: key into the module-level ``steps`` and ``weights`` tables.
            Any name containing ``'subseq'`` turns on the ``subseq`` flag that is
            forwarded to ``alignDTW``.
        F1, F2: feature matrices; ``shape[1]`` is compared to decide which one is
            passed first in subsequence mode.
        outfile: destination of the pickled warping path (also forwarded to
            ``alignDTW``).

    Returns:
        None. When ``alignDTW`` returns ``None`` nothing is written by this function.
    """
    subseq = 'subseq' in dtw_version

    # In subsequence mode the matrix with fewer columns is always passed first.
    swapped = subseq and (F2.shape[1] < F1.shape[1])
    first, second = (F2, F1) if swapped else (F1, F2)

    wp = alignDTW(first, second, steps=steps[dtw_version], weights=weights[dtw_version],
                  downsample=1, outfile=outfile, subseq=subseq)

    if wp is None:
        # No path to store. Checking before the flip below avoids indexing None,
        # which previously raised TypeError in the swapped branch.
        return

    if swapped:
        # Reverse the row order so the result lines up with the caller's (F1, F2)
        # order again (assumes one row per input sequence — TODO confirm).
        wp = wp[::-1, :]

    # Context manager guarantees the handle is flushed and closed.
    with open(outfile, 'wb') as f:
        pickle.dump(wp, f)

In [None]:
def run_all_benchmarks(outdir):
    """Run every benchmark in BENCHMARKS over all pairs listed in QUERY_LIST.

    Each non-blank line of QUERY_LIST must contain exactly two
    whitespace-separated entries. The query id of a pair is the basename of the
    first entry, ``'__'``, and the basename of the second entry.

    Args:
        outdir: root directory handed to ``run_benchmark_batch`` for the results.

    Raises:
        AssertionError: if a non-blank line does not contain exactly two entries.
    """
    parts_batch = []
    queryids = []
    with open(QUERY_LIST, 'r') as f:
        for lineno, line in enumerate(tqdm(f), start=1):
            parts = line.split()
            if not parts:
                # Tolerate blank lines (e.g. an empty line at the end of the file).
                continue
            assert len(parts) == 2, \
                f'{QUERY_LIST}:{lineno}: expected 2 entries, got {len(parts)}'

            parts_batch.append(parts)
            queryids.append(os.path.basename(parts[0]) + '__' + os.path.basename(parts[1]))

    for benchmark in BENCHMARKS:
        run_benchmark_batch(benchmark, FEAT_DIRS[benchmark][0], FEAT_DIRS[benchmark][1],
                            parts_batch, outdir, queryids, n_cores=4)

In [None]:
def run_benchmark_batch(benchmark, featdir1, featdir2, parts_batch, outdir, queryids, n_cores):
    """Run all SYSTEMS on a batch of feature-file pairs using a process pool.

    Args:
        benchmark: benchmark name, used for the output sub-directory.
        featdir1, featdir2: directories holding the first / second feature file
            of each pair.
        parts_batch: sequence of two-element entries; each element is a path
            relative to the matching feature directory (suffix replaced by ``.npy``).
        outdir: root directory for the results (see ``get_outfile``).
        queryids: one id per entry of ``parts_batch``, used as output file stem.
        n_cores: number of worker processes. A falsy value falls back to
            ``cpu_count() - 1`` (at least 1).

    Returns:
        None.
    """
    assert len(parts_batch) == len(queryids)

    inputs = []
    for parts, queryid in zip(parts_batch, queryids):
        featfile1 = (featdir1 / parts[0]).with_suffix('.npy')
        featfile2 = (featdir2 / parts[1]).with_suffix('.npy')

        F1 = np.load(featfile1)
        F2 = np.load(featfile2)

        # Replace exact zeros by a tiny positive value (presumably to keep the
        # downstream cost computation well defined — TODO confirm in alignDTW).
        F1[F1 == 0] = 1e-9
        F2[F2 == 0] = 1e-9

        for system in SYSTEMS:
            inputs.append((system, F1, F2, get_outfile(outdir, benchmark, system, queryid)))

    if not inputs:
        # Nothing to do; avoid spawning a pool for an empty batch.
        return

    # Honour the requested worker count; never ask for fewer than one process
    # (cpu_count() - 1 is 0 on a single-core machine, which Pool rejects).
    if n_cores:
        n_procs = max(1, int(n_cores))
    else:
        n_procs = max(1, multiprocessing.cpu_count() - 1)
    n_procs = min(n_procs, len(inputs))

    # process files in parallel; the context manager tears the workers down
    # so repeated calls do not accumulate idle processes.
    with multiprocessing.Pool(processes=n_procs) as pool:
        pool.starmap(dtw, inputs)

    return

In [None]:
def run_benchmark(benchmark, featdir1, featdir2, parts, outdir, queryid):
    """Run every system in SYSTEMS on a single pair of feature files.

    ``parts[0]`` and ``parts[1]`` are resolved against ``featdir1`` and
    ``featdir2`` respectively, with the suffix replaced by ``.npy``. Exact zeros
    in both loaded matrices are overwritten with ``1e-9`` before alignment, and
    each system's result goes to the path returned by ``get_outfile``.
    """
    feat_paths = [
        (featdir1 / parts[0]).with_suffix('.npy'),
        (featdir2 / parts[1]).with_suffix('.npy'),
    ]
    feats = [np.load(path) for path in feat_paths]

    # Overwrite exact zeros in place with a tiny positive value.
    for feat in feats:
        feat[feat == 0] = 1e-9
    F1, F2 = feats

    # run all 3 baselines
    for system in SYSTEMS:
        target = get_outfile(outdir, benchmark, system, queryid)
        dtw(system, F1, F2, target)

In [None]:
# Results for this training set are written below experiments_train/<TRAIN_SET>/,
# in one <benchmark>/<system>/ sub-directory per combination.
outdir = Path(f'experiments_train/{TRAIN_SET}')
run_all_benchmarks(outdir)