In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os.path
from pathlib import Path
import pickle
import multiprocessing
import time
import gc
from tqdm import tqdm

In [2]:
import import_ipynb

In [3]:
import DTW

In [4]:
import NWTW

The Cython extension is already loaded. To reload it, use:
  %reload_ext Cython


In [5]:
import FlexDTW

The Cython extension is already loaded. To reload it, use:
  %reload_ext Cython


In [6]:
DATASET = 'train' # 'test'
VERSION = 'list'

In [7]:
QUERY_LIST = Path(f'/home/asharma/ttmp/Flex/FlexDTW/cfg_files/queries.test.full')

In [8]:
SYSTEMS = ['dtw1', 'dtw2', 'dtw3', 'subseqdtw1', 'subseqdtw2', 'subseqdtw3', 'nwtw', 'flexdtw']
BENCHMARKS = ['matching', 'subseq_20', 'subseq_30', 'subseq_40', 'partialStart', 'partialEnd', 'partialOverlap', 
              'pre_5', 'pre_10', 'pre_20', 'post_5', 'post_10', 'post_20', 'prepost_5', 'prepost_10',
              'prepost_20']

In [9]:
features_root = Path('Chopin_Mazurkas_features')
FEAT_DIRS = {}

for benchmark in BENCHMARKS:
    if benchmark == 'partialOverlap':
        FEAT_DIRS[benchmark] = ([features_root/'partialStart', features_root/'partialEnd'])
    elif 'prepost' in benchmark:
        sec = benchmark.split('_')[-1]
        FEAT_DIRS[benchmark] = ([features_root/f'pre_{sec}', features_root/f'post_{sec}'])
    else:
        FEAT_DIRS[benchmark] = [features_root/f'{benchmark}', features_root/'original']

In [10]:
steps = {'dtw1': np.array([1,1,1,2,2,1]).reshape((-1,2)),
        'dtw2': np.array([1,1,1,2,2,1]).reshape((-1,2)),
        'dtw3': np.array([1,1,1,2,2,1]).reshape((-1,2)),
        'subseqdtw1': np.array([1,1,1,2,2,1]).reshape((-1,2)),
        'subseqdtw2': np.array([1,1,1,2,2,1]).reshape((-1,2)),
        'subseqdtw3': np.array([1,1,1,2,2,1]).reshape((-1,2)),
        'nwtw': 0, # transitions are specified in NWTW algorithm
        'flexdtw': np.array([1,1,1,2,2,1]).reshape((-1,2))
        }
weights = {'dtw1': np.array([2,3,3]),
          'dtw2': np.array([1,1,1]),
          'dtw3': np.array([1,2,2]),
          'subseqdtw1': np.array([1,1,2]),
          'subseqdtw2': np.array([2,3,3]),
          'subseqdtw3': np.array([1,2,2]),
          'nwtw': 0, # weights are specified in NWTW algorithm
          'flexdtw': np.array([1.25,3,3])
          }
other_params = {
                'flexdtw': {'beta': 0.1}
               }

In [11]:
def get_outfile(outdir, benchmark, system, queryid):
    outpath = (outdir / benchmark / system)
    outpath.mkdir(parents=True, exist_ok=True)
    outfile = (outpath / queryid).with_suffix('.pkl')
    return outfile

# Experimentation Algorithms

### Testing/Audio playing 

In [44]:
## Block to play audio segments by frame ranges

from IPython.display import Audio, display
import soundfile as sf
import numpy as np
import os

# --- INPUTS ---

featfile1 = "/home/asharma/ttmp/Flex/FlexDTW/Chopin_Mazurkas_features/original/Chopin_Op017No4/Chopin_Op017No4_Rubinstein-1952_pid9075-13.npy" 
featfile2 = "/home/asharma/ttmp/Flex/FlexDTW/Chopin_Mazurkas_features/original/Chopin_Op017No4/Chopin_Op017No4_Smith-1975_pid9054-13.npy"



hop_length = 512   # must match your features

# Frame ranges (assumed valid)
f1_start, f1_end = 10000, 120000
f2_start, f2_end = 6000, 8000 

# --- Helper to map featfile → wavfile ---
def feat_to_wav(feat_path):
    name = os.path.basename(feat_path).replace(".npy", ".wav")
    if "original" in feat_path:
        # special case
        subdir = feat_path.split("original/")[-1]        # e.g. Chopin_Op017No4/Chopin_Op017No4_...
        dirpath = os.path.dirname(subdir)                # e.g. Chopin_Op017No4
        return f"/home/asharma/ttmp/Chopin_Mazurkas/wav_22050_mono/{dirpath}/{name}"
    else:
        return feat_path.replace("Chopin_Mazurkas_features", "Chopin_Mazurkas_audios").replace(".npy", ".wav")

wavfile1 = feat_to_wav(featfile1)
wavfile2 = feat_to_wav(featfile2)

print("Using WAVs:\n", wavfile1, "\n", wavfile2)

# --- Load audio (mono) ---
y1, sr1 = sf.read(wavfile1, dtype="float32", always_2d=False)
y2, sr2 = sf.read(wavfile2, dtype="float32", always_2d=False)
if y1.ndim == 2: y1 = y1.mean(axis=1)
if y2.ndim == 2: y2 = y2.mean(axis=1)

def slice_by_frames(y, sr, f_start, f_end, hop):
    s0 = int(f_start * hop)
    s1 = int(f_end   * hop)
    seg = y[s0:s1]
    dur = (s1 - s0) / sr
    print(f"frames {f_start}→{f_end} | samples {s0}→{s1} | duration ≈ {dur:.3f}s")
    return seg

print("Recording 1 segment:")
seg1 = slice_by_frames(y1, sr1, f1_start, f1_end, hop_length)
display(Audio(seg1, rate=sr1))

print("\nRecording 2 segment:")
seg2 = slice_by_frames(y2, sr2, f2_start, f2_end, hop_length)
display(Audio(seg2, rate=sr2))


Using WAVs:
 /home/asharma/ttmp/Chopin_Mazurkas/wav_22050_mono/Chopin_Op017No4/Chopin_Op017No4_Rubinstein-1952_pid9075-13.wav 
 /home/asharma/ttmp/Chopin_Mazurkas/wav_22050_mono/Chopin_Op017No4/Chopin_Op017No4_Smith-1975_pid9054-13.wav
Recording 1 segment:
frames 10000→120000 | samples 5120000→61440000 | duration ≈ 2554.195s



Recording 2 segment:
frames 6000→8000 | samples 3072000→4096000 | duration ≈ 46.440s


### Stage 1: run flex on blocks

In [13]:
def align_system(
    system,
    F1,
    F2,
    outfile=None,
    L_block='auto',         # tile size along both axes (frames). Use 'auto' or None to choose 2-5 blocks
    min_block=5,            # skip tiny ragged tiles smaller than this on any side
    plot_overlay=True,      # make a single global overlay figure
    show_block_boxes=True,  # draw dashed rectangles for each tile
    overlay_full=True       # overlay full-matrix FlexDTW path
):
    """
    Tile the FULL cost matrix C into L_block × L_block blocks (ragged edges handled),
    run FlexDTW on EACH block, store GLOBAL (i,j) paths, and visualize in Plotly.

    If L_block is 'auto' or None, choose L_block so that the matrix is split
    into 2..5 blocks per axis (try all combinations) and pick the split that
    minimizes the total "sliver" / raggedness. The resulting L_block is a square
    tile size = max(ceil(L1 / n_row), ceil(L2 / n_col)).
    """
    assert system == 'flexdtw', "This implementation targets 'flexdtw' only."

    # ---------- Build cost matrix once ----------
    L1 = F1.shape[1]
    L2 = F2.shape[1]
    if L1 == 0 or L2 == 0:
        raise ValueError("Empty features: F1 or F2 has zero length.")

    # --- Auto-select L_block to avoid tiny slivers if requested ---
    if L_block in (None, 'auto'):
        best = None  # (sliver_metric, n_row, n_col, chosen_block)
        for n_row in range(2, 6):      # try 2..5 blocks along rows
            for n_col in range(2, 6):  # try 2..5 blocks along cols
                # block sizes needed to cover each axis
                r_size = int(np.ceil(L1 / n_row))
                c_size = int(np.ceil(L2 / n_col))

                # compute sliver: how many frames would be "extra" if we tile with these sizes
                # (r_size * n_row - L1) is the total pad on rows; similarly for cols.
                pad_rows = r_size * n_row - L1
                pad_cols = c_size * n_col - L2
                total_pad = pad_rows + pad_cols

                # also compute worst tiny-piece risk: smallest tile dimension
                # (we prefer choices where r_size and c_size >= min_block)
                smallest_tile = min(r_size, c_size)

                # Metric: prefer minimal total_pad (sliver), then prefer larger smallest_tile,
                # then fewer total tiles (n_row * n_col), then smaller max block size.
                metric = (total_pad, -smallest_tile, n_row * n_col, max(r_size, c_size))

                chosen_block = max(r_size, c_size)  # create square tiles of this size
                if best is None or metric < best[0]:
                    best = (metric, n_row, n_col, chosen_block)

        # If nothing matched (shouldn't happen), fallback to previous default 4000
        if best is not None:
            _, n_row_best, n_col_best, L_block_chosen = best
            # ensure chosen block is at least min_block
            L_block = max(int(L_block_chosen), int(min_block))
            # store for debugging / downstream use (optional)
            auto_info = {
                'auto_n_row': n_row_best,
                'auto_n_col': n_col_best,
                'auto_L_block': L_block
            }
        else:
            L_block = max(4000, min_block)
            auto_info = {'auto_fallback': True}
    else:
        auto_info = None

    F1n = FlexDTW.L2norm(F1)      # (D, L1)
    F2n = FlexDTW.L2norm(F2)      # (D, L2)
    C = 1.0 - F1n.T @ F2n         # (L1, L2), cosine distance

    # ---------- Optional full-matrix FlexDTW (for overlay) ----------
    full_global = {'best_cost': None, 'wp': None}
    if overlay_full:
        beta_full = other_params['flexdtw']['beta']
        buffer_full = min(L1, L2) * (1 - (1 - beta_full) * min(L1, L2) / max(L1, L2))
        best_cost_full, wp_full, debug_full = FlexDTW.flexdtw(
            C, steps=steps['flexdtw'], weights=weights['flexdtw'], buffer=buffer_full
        )
        # switch path to (N,2) if returned as (2,N)
        if wp_full.ndim == 2 and wp_full.shape[0] == 2:
            wp_full = wp_full.T
        full_global = {'best_cost': float(best_cost_full), 'wp': wp_full}

    # ---------- Tile C into blocks ----------
    n_row = (L1 + L_block - 1) // L_block   # ceil
    n_col = (L2 + L_block - 1) // L_block   # ceil

    blocks = []
    for bi in range(n_row):
        r0 = bi * L_block
        r1 = min((bi + 1) * L_block, L1)
        for bj in range(n_col):
            c0 = bj * L_block
            c1 = min((bj + 1) * L_block, L2)

            block = C[r0:r1, c0:c1]
            R, Cw = block.shape
            if R < min_block or Cw < min_block:
                # Skip tiny ragged tiles that tend to produce trivial 2-point paths
                continue

            # Per-block buffer (scale to local sizes)
            beta = other_params['flexdtw']['beta']
            buffer_blk = min(R, Cw) * (1 - (1 - beta) * min(R, Cw) / max(R, Cw))

            # Run FlexDTW on the block
            best_cost_blk, wp_local, debug_blk = FlexDTW.flexdtw(
                block, steps=steps['flexdtw'], weights=weights['flexdtw'], buffer=buffer_blk
            )
            # switch path to (N,2) if returned as (2,N)
            if wp_local.ndim == 2 and wp_local.shape[0] == 2:
                wp_local = wp_local.T

            raw_cost_blk = float(block[wp_local[:, 0], wp_local[:, 1]].sum())
            path_len_blk = int(np.abs(np.diff(wp_local, axis=0)).sum(axis=1).sum() + 1)

            # Map local (i,j) -> GLOBAL indices
            wp_global = np.column_stack([wp_local[:, 0] + r0, wp_local[:, 1] + c0])

            blocks.append({
                'bi': bi, 'bj': bj,
                'rows': (r0, r1),
                'cols': (c0, c1),
                'Ck_shape': (R, Cw),
                'best_cost': float(best_cost_blk),
                'wp_global': wp_global,
                'raw_cost': raw_cost_blk,
                'path_len': path_len_blk
            })

    # ---------- Plotly overlay ----------
    if plot_overlay:
        fig = go.Figure()

        # (Optional) draw dashed rectangles for each tile
        if show_block_boxes:
            for b in blocks:
                (r0, r1), (c0, c1) = b['rows'], b['cols']
                fig.add_shape(
                    type="rect",
                    x0=c0, x1=c1, y0=r0, y1=r1,
                    line=dict(color="rgba(120,120,120,0.8)", width=1, dash="dash"),
                    fillcolor="rgba(0,0,0,0)",
                    layer="below"
                )

        # Per-block paths
        palette = pc.qualitative.Plotly  # good default categorical palette
        for idx, b in enumerate(blocks):
            wp = b['wp_global']
            fig.add_trace(go.Scatter(
                x=wp[:, 1], y=wp[:, 0],
                mode="lines",
                name=f"blk ({b['bi']},{b['bj']})",
                line=dict(width=2, color=palette[idx % len(palette)]),
                hovertemplate=("blk (%{customdata[0]},%{customdata[1]})<br>"
                               "j=%{x}, i=%{y}<extra></extra>"),
                customdata=np.tile([b['bi'], b['bj']], (wp.shape[0], 1))
            ))

        # Overlay full-matrix path (bold black)
        if overlay_full and full_global['wp'] is not None:
            wp = full_global['wp']
            # Optional downsampling for very long paths (purely visual)
            step = max(1, len(wp) // 5000)
            fig.add_trace(go.Scatter(
                x=wp[::step, 1], y=wp[::step, 0],
                mode="lines",
                name=f"Global DTW (cost={full_global['best_cost']:.3f})",
                line=dict(color="black", width=3),
                opacity=0.95
            ))

        # optionally annotate auto info
        if auto_info is not None:
            ann_text = f"auto L_block={auto_info.get('auto_L_block', '?')}"
            if 'auto_n_row' in auto_info:
                ann_text += f", n_row={auto_info['auto_n_row']}, n_col={auto_info['auto_n_col']}"
            fig.add_annotation(x=0.99 * L2, y=0.01 * L1,
                               text=ann_text, showarrow=False, xanchor="right", yanchor="bottom",
                               bgcolor="rgba(255,255,255,0.7)")

        fig.update_layout(
            title="All block DTW paths (global coords)" + (
                " + Global path" if overlay_full and full_global['wp'] is not None else ""
            ),
            xaxis_title="F2 frame j (global)",
            yaxis_title="F1 frame i (global)",
            width=900, height=700,
            template="plotly_white",
            legend=dict(orientation="h")
        )
        # Equal aspect; full bounds
        fig.update_xaxes(range=[0, L2], showgrid=False)
        fig.update_yaxes(range=[0, L1], showgrid=False, scaleanchor="x", scaleratio=1)
        fig.show()

    # ---------- Persist ----------
    result = {
        'C_shape': (L1, L2),
        'L_block': L_block,
        'blocks': blocks,
        'full_global': full_global,
        'C': C,
        'auto_info': auto_info
    }
    if outfile:
        pickle.dump(result, open(outfile, 'wb'))

    return C, result


In [14]:
## align_system_split_recordings: Function to align system by splitting only the reference rather than LXL. (ARCHIVE for now)

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import pickle
import plotly.graph_objects as go

def align_system_split_recordings(system, F1, F2, outfile,
                                  L=500,            # F2 window length in frames
                                  plot_individual=True,
                                  plot_overlay=True):
    
    """
    Slice ONLY F2 into non-overlapping windows of length L (last window may be shorter).
    For each window k: run FlexDTW on (ALL of F1) x (F2[:, s2:e2]).
    Store each warping path in GLOBAL coords and visualize.

    Returns:
        seg_info:   list of metadata dicts per F2 window
        paths_global: list of np.ndarray [Nk, 2] (i_global, j_global)
        best_costs: list of float costs
        full_wp_concat: vertical concat of all paths (for convenience)
    """
    assert system == 'flexdtw', "This function only handles 'flexdtw'."

    D1, L1 = F1.shape
    D2, L2 = F2.shape
    if L1 == 0 or L2 == 0:
        raise ValueError("Empty features.")

    # Number of F2 windows (ceil)
    n_win = (L2 + L - 1) // L

    paths_global = []   # list[np.ndarray [Nk,2]] of (i_global, j_global)
    best_costs = []     # list[float]
    seg_info = []       # metadata per window

    # Precompute normalized features once
    F1n = FlexDTW.L2norm(F1)        # (D, L1)
    F2n_all = FlexDTW.L2norm(F2)    # (D, L2)

    for k in range(n_win):
        # F2 window bounds (global)
        s2, e2 = k * L, min((k + 1) * L, L2)
        if e2 <= s2:
            continue

        F2_win = F2n_all[:, s2:e2]        # normalized slice
        # Local cost: ALL F1 vs F2 window (size L1 x L2k)
        Ck = 1.0 - F1n.T @ F2_win
        L1k, L2k = Ck.shape

        # Per-window buffer (same form as your code)
        beta = other_params['flexdtw']['beta']
        buffer_k = min(L1k, L2k) * (1 - (1 - beta) * min(L1k, L2k) / max(L1k, L2k))

        # Run FlexDTW (use your API; include steps/weights if required)
        best_cost, wp_local, debug = FlexDTW.flexdtw(
            Ck, steps=steps['flexdtw'], weights=weights['flexdtw'], buffer=buffer_k
            # If your impl allows: FlexDTW.flexdtw(Ck, buffer=buffer_k)
        )

        # Ensure path shape is N×2 (FlexDTW often returns 2×N)
        if wp_local.ndim == 2 and wp_local.shape[0] == 2:
            wp_local = wp_local.T


        # Map local -> GLOBAL (i from [0..L1), j from [s2..e2))
        wp_global = np.column_stack([wp_local[:, 0], wp_local[:, 1] + s2])

        # Store
        seg_info.append({
            'k': k,
            'F1_full': (0, L1),
            'F2_win': (s2, e2),
            'Ck_shape': (L1k, L2k),
            'path_len': int(len(wp_global)),
            'cost': float(best_cost),
        })
        paths_global.append(wp_global)
        best_costs.append(best_cost)

        print(f"[win {k}] F2[{s2}:{e2})  Ck={Ck.shape}  path_len={len(wp_global)}  cost={best_cost:.4f}")

        # Per-window plot (GLOBAL coords), zoomed to this F2 window across full F1
        if plot_individual:

            fig = go.Figure()

            # F2 window rectangle spanning all F1 rows (drawn behind the path)
            fig.add_shape(
                type="rect",
                x0=s2, x1=e2, y0=0, y1=L1,
                line=dict(color="rgba(80,80,80,0.9)", width=1, dash="dash"),
                fillcolor="rgba(0,0,0,0)",
                layer="below"
            )

            # Path (global coords): x=j, y=i
            fig.add_trace(go.Scatter(
                x=wp_global[:, 1], y=wp_global[:, 0],
                mode="lines",
                name=f"win {k} path",
                line=dict(width=2)
            ))

            # Layout to match your Matplotlib view
            fig.update_layout(
                title=f"F2 window {k} path (global)<br>F2[{s2}:{e2}) vs F1[0:{L1})",
                xaxis_title="F2 frame j (global)",
                yaxis_title="F1 frame i (global)",
                width=450, height=450,
                template="plotly_white",
                showlegend=False
            )
            fig.update_xaxes(range=[s2, e2], showgrid=False)
            fig.update_yaxes(range=[0, L1], showgrid=False, scaleanchor="x", scaleratio=1)
            fig.show()


    # Convenience concat (paths are independent)
    full_wp_concat = np.vstack(paths_global) if paths_global else np.zeros((0, 2), dtype=int)

    # Global overlay: all windows & all paths
    if plot_overlay and paths_global:


        fig = go.Figure()

        # Window boxes (behind everything)
        for meta in seg_info:
            s2_box, e2_box = meta['F2_win']
            fig.add_shape(
                type="rect",
                x0=s2_box, x1=e2_box, y0=0, y1=L1,
                line=dict(color="rgba(96,96,96,0.85)", width=1, dash="dash"),
                fillcolor="rgba(0,0,0,0)",
                layer="below"
            )

        # Per-window paths
        for k, wp in enumerate(paths_global):
            fig.add_trace(go.Scatter(
                x=wp[:, 1], y=wp[:, 0],
                mode="lines",
                name=f"win {k}",
                line=dict(width=2)
            ))

        # --- Global FlexDTW overlay (exactly what you had, just plotted in Plotly) ---
        C_full = 1.0 - FlexDTW.L2norm(F1).T @ FlexDTW.L2norm(F2)
        beta_full = other_params['flexdtw']['beta']
        buffer_full = min(L1, L2) * (1 - (1 - beta_full) * min(L1, L2) / max(L1, L2))
        best_cost_full, wp_full, debug_full = FlexDTW.flexdtw(
            C_full, steps=steps['flexdtw'], weights=weights['flexdtw'], buffer=buffer_full
        )
        if wp_full.ndim == 2 and wp_full.shape[0] == 2:
            wp_full = wp_full.T  # ensure (N, 2)

        step = max(1, len(wp_full) // 5000)  # optional visual downsample
        fig.add_trace(go.Scatter(
            x=wp_full[::step, 1], y=wp_full[::step, 0],
            mode="lines",
            name=f"Global DTW (cost={best_cost_full:.3f})",
            line=dict(color="black", width=3)
        ))
        # ------------------------------------------------------------------------------
        
        fig.update_layout(
            title="All F2-window paths vs FULL F1 (global coords)",
            xaxis_title="F2 frame j (global)",
            yaxis_title="F1 frame i (global)",
            width=750, height=560,
            template="plotly_white",
            legend=dict(orientation="h")
        )
        fig.update_xaxes(range=[0, L2], showgrid=False)
        fig.update_yaxes(range=[0, L1], showgrid=False, scaleanchor="x", scaleratio=1)
        fig.show()

    return seg_info, paths_global, best_costs, full_wp_concat


In [15]:
## align_system_multiple_paths: Function to align system by extracting multiple paths from a single cost block (v0.1g)
def align_system_multiple_paths(
    system,
    F1,
    F2,
    outfile=None,
    L_block=2000,          # square chunk size (L x L)
    min_block=5,           # skip tiny ragged tiles
    mask_radius=20,        # 0 = mask exact cells only
    plot_overlay=True,     # show a global overlay figure
    show_block_boxes=True, # dashed rectangles per block
    overlay_full=True      # also compute/display full-matrix path
):
    """
    For each LxL chunk: run FlexDTW twice (masking a Chebyshev corridor around
    the first path), record BOTH local paths, and (optionally) visualize.

    Returns a 'tiled_result'-style dict:

      {
        'C_shape': (L1, L2),
        'L_block': L_block,
        'C': global_cost_matrix,
        'blocks': [
          {
            'bi','bj',
            'rows': (r0,r1), 'cols': (c0,c1),
            # per-path local (global coords):
            'wp_global_p1': np.ndarray [N1,2],
            'wp_global_p2': np.ndarray [N2,2],
            # optional stats:
            'best_cost_p1': float,        # DTW objective returned by flexdtw on the block
            'best_cost_p2': float,
            'raw_cost_p1': float,         # sum of local block C along path
            'raw_cost_p2': float,
            'path_len_p1': int,           # approx Manhattan hops (+1)
            'path_len_p2': int
          }, ...
        ],
        'full_global': {'best_cost': float or None, 'wp': np.ndarray|None}
      }
    """
    import numpy as np
    import pickle

    assert system == "flexdtw", "This helper only supports 'flexdtw'."

    # ----- Build cost matrix once (cosine distance) -----
    L1 = F1.shape[1]
    L2 = F2.shape[1]
    if L1 == 0 or L2 == 0:
        raise ValueError("Empty features: F1 or F2 has zero length.")

    F1n = FlexDTW.L2norm(F1)        # (D, L1)
    F2n = FlexDTW.L2norm(F2)        # (D, L2)
    C   = 1.0 - F1n.T @ F2n         # (L1, L2), cosine distance

    # penalty size for masking
    mC = float(np.max(C)) if C.size else 1.0
    penalty = (mC if mC > 0.0 else 1.0) * 1e6

    # ---------- Optional full-matrix FlexDTW (for overlay) ----------
    full_global = {'best_cost': None, 'wp': None}
    if overlay_full:
        buffer_full = max(1, int(0.1 * min(L1, L2)))
        best_cost_full, wp_full, _ = FlexDTW.flexdtw(
            C, steps=steps['flexdtw'], weights=weights['flexdtw'], buffer=buffer_full
        )
        if wp_full.ndim == 2 and wp_full.shape[0] == 2:
            wp_full = wp_full.T
        full_global = {'best_cost': float(best_cost_full), 'wp': wp_full}

    # ----- Iterate square LxL blocks -----
    blocks = []   # list of per-block dicts (for tiled_result-style output)

    n_row = (L1 + L_block - 1) // L_block
    n_col = (L2 + L_block - 1) // L_block

    # simple in-place Chebyshev mask
    def _mask_exact_or_radius(Csub, wp_local, rad, bump):
        R, W = Csub.shape
        if wp_local.size == 0:
            return
        for (ri, ci) in wp_local:
            r = int(ri); c = int(ci)
            r0 = max(0, r - rad); r1 = min(R - 1, r + rad)
            c0 = max(0, c - rad); c1 = min(W - 1, c + rad)
            Csub[r0:r1+1, c0:c1+1] += bump  # Chebyshev square corridor (fast)

    # plotly (optional)
    fig = None
    if plot_overlay:
        try:
            import plotly.graph_objects as go
            import plotly.colors as pc
            fig = go.Figure()
            palette = pc.qualitative.Plotly
        except Exception:
            fig = None

    for bi in range(n_row):
        r0 = bi * L_block
        r1 = min((bi + 1) * L_block, L1)
        for bj in range(n_col):
            c0 = bj * L_block
            c1 = min((bj + 1) * L_block, L2)

            block = C[r0:r1, c0:c1]
            Rb, Cb = block.shape
            if Rb < min_block or Cb < min_block:
                continue

            # local buffer sized to block
            buffer_blk = max(1, int(0.1 * min(Rb, Cb)))

            # --- first path on this block ---
            C_work = block.copy()
            cost1, wp1, _ = FlexDTW.flexdtw(
                C_work, steps=steps['flexdtw'], weights=weights['flexdtw'], buffer=buffer_blk
            )
            if wp1.ndim == 2 and wp1.shape[0] == 2:
                wp1 = wp1.T

            # local stats for p1
            raw_cost_p1 = float(block[wp1[:, 0], wp1[:, 1]].sum()) if wp1.size else 0.0
            path_len_p1 = int(np.abs(np.diff(wp1, axis=0)).sum(axis=1).sum() + 1) if wp1.shape[0] > 1 else 1

            # mask corridor of path #1
            _mask_exact_or_radius(C_work, wp1, int(max(0, mask_radius)), penalty)

            # --- second path on this block ---
            cost2, wp2, _ = FlexDTW.flexdtw(
                C_work, steps=steps['flexdtw'], weights=weights['flexdtw'], buffer=buffer_blk
            )
            if wp2.ndim == 2 and wp2.shape[0] == 2:
                wp2 = wp2.T

            # local stats for p2
            raw_cost_p2 = float(block[wp2[:, 0], wp2[:, 1]].sum()) if wp2.size else 0.0
            path_len_p2 = int(np.abs(np.diff(wp2, axis=0)).sum(axis=1).sum() + 1) if wp2.shape[0] > 1 else 1

            # map to GLOBAL coords
            wp1g = np.column_stack([wp1[:, 0] + r0, wp1[:, 1] + c0]) if wp1.size else np.empty((0,2), dtype=int)
            wp2g = np.column_stack([wp2[:, 0] + r0, wp2[:, 1] + c0]) if wp2.size else np.empty((0,2), dtype=int)

            blocks.append({
                'bi': bi, 'bj': bj,
                'rows': (r0, r1),
                'cols': (c0, c1),
                'wp_global_p1': wp1g,
                'wp_global_p2': wp2g,
                'best_cost_p1': float(cost1),
                'best_cost_p2': float(cost2),
                'raw_cost_p1': raw_cost_p1,
                'raw_cost_p2': raw_cost_p2,
                'path_len_p1': path_len_p1,
                'path_len_p2': path_len_p2
            })

            # draw per-block rectangles + both paths
            if fig is not None:
                if show_block_boxes:
                    fig.add_shape(
                        type="rect",
                        x0=c0, x1=c1, y0=r0, y1=r1,
                        line=dict(color="rgba(120,120,120,0.8)", width=1, dash="dash"),
                        fillcolor="rgba(0,0,0,0)",
                        layer="below"
                    )
                col1 = palette[(2 * (bi * n_col + bj)) % len(palette)]
                col2 = palette[(2 * (bi * n_col + bj) + 1) % len(palette)]
                if wp1g.size:
                    fig.add_trace(go.Scatter(
                        x=wp1g[:, 1], y=wp1g[:, 0],
                        mode="lines",
                        name=f"blk ({bi},{bj}) p1",
                        line=dict(width=2, color=col1),
                        hovertemplate=("blk (%{customdata[0]},%{customdata[1]}) p1<br>"
                                       "j=%{x}, i=%{y}<extra></extra>"),
                        customdata=np.tile([bi, bj], (wp1g.shape[0], 1)),
                        opacity=0.7
                    ))
                if wp2g.size:
                    fig.add_trace(go.Scatter(
                        x=wp2g[:, 1], y=wp2g[:, 0],
                        mode="lines",
                        name=f"blk ({bi},{bj}) p2",
                        line=dict(width=2, color=col2),
                        hovertemplate=("blk (%{customdata[0]},%{customdata[1]}) p2<br>"
                                       "j=%{x}, i=%{y}<extra></extra>"),
                        customdata=np.tile([bi, bj], (wp2g.shape[0], 1)),
                        opacity=0.7
                    ))

    # overlay full-matrix path (bold black)
    if fig is not None and overlay_full and full_global['wp'] is not None:
        wp = full_global['wp']
        step = max(1, len(wp) // 5000)
        fig.add_trace(go.Scatter(
            x=wp[::step, 1], y=wp[::step, 0],
            mode="lines",
            name=f"Global DTW" + (f" (cost={full_global['best_cost']:.3f})" if full_global['best_cost'] is not None else ""),
            line=dict(color="black", width=3),
            opacity=0.95
        ))
        # axes/format
        fig.update_layout(
            title="Two DTW paths per L×L block",
            xaxis_title="F2 frame j (global)",
            yaxis_title="F1 frame i (global)",
            width=900, height=700,
            template="plotly_white",
            legend=dict(orientation="h")
        )
        fig.update_xaxes(range=[0, L2], showgrid=False)
        fig.update_yaxes(range=[0, L1], showgrid=False, scaleanchor="x", scaleratio=1)
        fig.show()

    # ---------- Persist + Return tiled_result ----------
    result = {
        'C_shape': (L1, L2),
        'L_block': L_block,
        'C': C,                 # include global cost matrix for later stitched scoring
        'blocks': blocks,
        'full_global': full_global
    }

    if outfile:
        pickle.dump(result, open(outfile, "wb"))

    return result




### Stage 2: piece together path

In [16]:
##parlfex_e: v0.1d: combining cost and distance comparison metrics
import numpy as np
import plotly.graph_objects as go
import plotly.colors as pc

def parflex_e(tiled_result, C_global, show_fig=True):
    """
    Stage 2 (full matching) with discontinuity penalty.
    Decision at each block (i,j) chooses predecessor p in {up, left, diag} minimizing:
        (D[p] + C_chunk[i,j] + penalty) / (L[p] + L_chunk[i,j])
    where penalty = (C_chunk[i,j] / L_chunk[i,j]) * magn_discontinuity(p -> (i,j))
    Stores unnormalized totals:
        D_chunks[i,j] = numerator chosen (D[p] + C_chunk + penalty)
        L_chunks[i,j] = denominator chosen (L[p] + L_chunk)
    B_chunks stores predecessor coordinates (pi,pj) or (-1,-1) for origin.

    Inputs: tiled_result from stage1 (must contain 'blocks' with 'bi','bj','wp_global','rows','cols' and 'C_shape').
    Returns: dict containing C_chunk, L_chunk, D_chunks, L_chunks, B_chunks, ordered_blocks, avg_cost, stitched_wp, ...
    """
    blocks = tiled_result['blocks']
    if not blocks:
        raise ValueError("Stage 2: no blocks in tiled_result['blocks'].")

    # --- block grid shape ---
    n_row = max(b['bi'] for b in blocks) + 1
    n_col = max(b['bj'] for b in blocks) + 1
    have  = {(b['bi'], b['bj']) for b in blocks}
    by_ij = {(b['bi'], b['bj']): b for b in blocks}

    # --- Build per-block unnormalized cost & path length (same logic as your parflex_stage2) ---
    INF = 1e18
    C_chunk = np.full((n_row, n_col), np.nan, dtype=float)  # unnormalized block cost
    L_chunk = np.full((n_row, n_col), np.nan, dtype=float)  # block Manhattan length

    for b in blocks:
        bi, bj = b['bi'], b['bj']

        # Prefer stage-1 raw_cost if present, else approximate 
        raw = float(b['raw_cost'])
        # else:
        #     assert(False)
        #     # wp = np.asarray(b['wp_global'])
        #     # if wp.ndim != 2 or wp.shape[0] < 1:
        #     #     raw = float(b.get('best_cost', 0.0))
        #     # else:
        #     #     plen = int(b.get('path_len', max(1, int(np.abs(np.diff(wp, axis=0)).sum(axis=1).sum() + 1))))
        #     #     raw = float(b.get('best_cost', 0.0)) * max(plen, 1)

        # if 'path_len' in b:
        plen = int(b['path_len'])
        # # else:
        # #     wp = np.asarray(b['wp_global'])
        # #     plen = int(np.abs(np.diff(wp, axis=0)).sum(axis=1).sum() + 1) if wp.shape[0] >= 2 else 1

        C_chunk[bi, bj] = raw
        L_chunk[bi, bj] = max(plen, 1.0)

    # Fill missing tiles with large penalty & unit length (same fallback as before)
    if np.isnan(C_chunk).any():
        finite = C_chunk[np.isfinite(C_chunk)]
        if finite.size == 0:
            raise ValueError("Stage 2: all block costs are missing.")
        penalty_val = np.nanpercentile(C_chunk, 95) + 6*(np.nanmedian(np.abs(finite - np.nanmedian(finite))) + 1e-6)
        C_chunk = np.where(np.isfinite(C_chunk), C_chunk, penalty_val)
    if np.isnan(L_chunk).any():
        L_chunk = np.where(np.isfinite(L_chunk), L_chunk, 1.0)

    # --- build start/end endpoints for magn_discontinuity like endpoint_distance_algorithm ---
    # Ensure each block has 'start' and 'end' in global coordinates (i_global, j_global)
    for b in blocks:
        wp = np.asarray(b['wp_global'])
        if wp.ndim != 2 or wp.shape[1] != 2:
            raise ValueError("wp_global must be (N,2) for each block.")
        s_idx = int(np.argmin(wp[:, 0] + wp[:, 1]))
        e_idx = int(np.argmax(wp[:, 0] + wp[:, 1]))
        b['start'] = (int(wp[s_idx, 0]), int(wp[s_idx, 1]))
        b['end']   = (int(wp[e_idx, 0]), int(wp[e_idx, 1]))

    def magn_discontinuity(prev_cell, cur_cell): 
        """
        Manhattan distance between prev.end and cur.start (global coords).
        If either block missing, return large value (so path avoids it if possible).
        """
        if prev_cell not in by_ij or cur_cell not in by_ij:
            return float('inf')
        ie_prev, je_prev = by_ij[prev_cell]['end']
        is_cur, js_cur = by_ij[cur_cell]['start']
        return float(abs(is_cur - ie_prev) + abs(js_cur - je_prev))


    # --- DP tables ---
    if (0, 0) not in have or (n_row-1, n_col-1) not in have:
        raise ValueError("Stage 2 full-matching requires blocks at (0,0) and (n_row-1,n_col-1).")

    # Store cumulative unnormalized numerator and denominator, and predecessor coords
    D_chunks = np.full((n_row, n_col), INF, dtype=float)   # numerator (unnormalized total cost)
    L_chunks = np.full((n_row, n_col), 0.0, dtype=float)   # denominator (total path length)
    B_chunks = np.full((n_row, n_col, 2), -1, dtype=int)   # predecessor coords (pi,pj) or (-1,-1)

    # Initialize origin
    D_chunks[0, 0] = C_chunk[0, 0]   # no discontinuity penalty at origin
    L_chunks[0, 0] = L_chunk[0, 0]
    B_chunks[0, 0] = [-1, -1]

    # DP sweep row-major
    for i in range(n_row):
        for j in range(n_col):
            if i == 0 and j == 0:
                continue
            if (i, j) not in have:
                continue

            best_ratio = float('inf')
            best_choice = None  # tuple (pi,pj, num, den)

            # iterate predecessors up, left, diag
            for (pi, pj) in ((i-1, j), (i, j-1), (i-1, j-1)):
                if pi < 0 or pj < 0:
                    continue
                if (pi, pj) not in have:
                    continue
                if not np.isfinite(D_chunks[pi, pj]) or D_chunks[pi, pj] >= INF/2:
                    continue

                # compute penalty
                # penalty = (C_chunk[i,j] / L_chunk[i,j]) * magn_discontinuity((pi,pj)->(i,j))
                base_ratio_factor = C_chunk[i, j] / max(L_chunk[i, j], 1e-12)
                mag = magn_discontinuity((pi, pj), (i, j))
                penalty = base_ratio_factor * mag

                num = D_chunks[pi, pj] + C_chunk[i, j] + penalty
                den = L_chunks[pi, pj] + L_chunk[i, j]
                ratio = num / max(den, 1.0)

                if ratio < best_ratio:
                    best_ratio = ratio
                    best_choice = (pi, pj, num, den)

            if best_choice is None:
                # unreachable (no valid predecessors)
                continue

            pi, pj, num, den = best_choice
            D_chunks[i, j] = num
            L_chunks[i, j] = den
            B_chunks[i, j] = [pi, pj]

    ei, ej = n_row - 1, n_col - 1
    # if D_chunks[ei, ej] >= INF/2:
    #     # fallback: find best reachable on bottom row or right column (like endpoint algorithm)
    #     reachable_goal = None
    #     best_val = INF
    #     for (ii, jj) in have:
    #         if (ii == ei or jj == ej) and np.isfinite(D_chunks[ii, jj]) and D_chunks[ii, jj] < best_val:
    #             best_val = D_chunks[ii, jj]
    #             reachable_goal = (ii, jj)
    #     if reachable_goal is None:
    #         raise RuntimeError("Stage 2: end block unreachable; check tiling or penalties.")
    #     else:
    #         ei, ej = reachable_goal

    # Backtrace ordered_blocks
    ordered_blocks = []
    i, j = ei, ej
    while True:
        ordered_blocks.append((i, j))
        pi, pj = int(B_chunks[i, j, 0]), int(B_chunks[i, j, 1])
        if pi == -1 and pj == -1:
            break
        i, j = pi, pj
    ordered_blocks = ordered_blocks[::-1]  # origin -> goal
    ordered_linear = [bi * n_col + bj for (bi, bj) in ordered_blocks]
    avg_cost = float(D_chunks[ei, ej] / max(L_chunks[ei, ej], 1.0))

    # --- stitch the per-block global wp paths along ordered_blocks (avoid duplicate junctions) ---
    stitched_parts = []
    last = None
    for (bi, bj) in ordered_blocks:
        wp = np.asarray(by_ij[(bi, bj)]['wp_global'], dtype=float)
        if wp.size == 0:
            continue
        if last is None:
            stitched_parts.append(wp)
            last = (wp[-1, 0], wp[-1, 1])
        else:
            # avoid duplicating identical join point
            if np.allclose(stitched_parts[-1][-1], wp[0], atol=1e-8):
                stitched_parts.append(wp[1:])
            else:
                stitched_parts.append(wp)
            last = (wp[-1, 0], wp[-1, 1])

    stitched_wp = np.vstack(stitched_parts) if stitched_parts else np.zeros((0, 2), dtype=float)

    # compute stitched raw cost and Manhattan distance and normalized cost (for reporting)
    total_normalized_from_stitched = float('inf')
    stitched_raw_cost = None
    if stitched_wp.size:
        # stitched_wp are global coords; tiled_result may or may not contain full global C matrix
        C_global = tiled_result.get('C', None)
        if C_global is not None:
            stitched_raw_cost = float(np.sum([C_global[int(i), int(j)] for (i, j) in stitched_wp.astype(int)]))
            si, sj = stitched_wp[0].astype(int)
            ei_g, ej_g = stitched_wp[-1].astype(int)
            mdist = abs(ei_g - si) + abs(ej_g - sj)
            total_normalized_from_stitched = stitched_raw_cost / mdist if mdist > 0 else float('inf')

    
            
    # --- optional plotting (similar style to parflex_stage2) ---
    if show_fig:
        L1, L2 = tiled_result['C_shape']
        fig = go.Figure()

        # 1) tile rectangles
        for b in blocks:
            (r0, r1), (c0, c1) = b['rows'], b['cols']
            fig.add_shape(
                type="rect",
                x0=c0, x1=c1, y0=r0, y1=r1,
                line=dict(color="rgba(140,140,140,0.6)", width=1, dash="dash"),
                fillcolor="rgba(0,0,0,0)",
                layer="below"
            )

        # 2) overlay all block paths (faded)
        palette = pc.qualitative.Plotly
        for idx, b in enumerate(blocks):
            wp = np.asarray(b['wp_global'])
            if wp.size:
                fig.add_trace(go.Scatter(
                    x=wp[:, 1], y=wp[:, 0],
                    mode="lines",
                    line=dict(width=2, color=palette[idx % len(palette)]),
                    name=f"blk({b['bi']},{b['bj']}) path",
                    showlegend=False,
                    opacity=0.45
                ))

        # 3) block order markers
        centers_x, centers_y, labels = [], [], []
        for t, (bi, bj) in enumerate(ordered_blocks, start=1):
            b = by_ij[(bi, bj)]
            (r0, r1), (c0, c1) = b['rows'], b['cols']
            centers_x.append((c0 + c1) / 2.0)
            centers_y.append((r0 + r1) / 2.0)
            labels.append(str(t))
        if centers_x:
            fig.add_trace(go.Scatter(
                x=centers_x, y=centers_y,
                mode="markers+text",
                marker=dict(size=18, symbol="circle-open-dot"),
                text=labels, textposition="middle center",
                name=f"chunk order (avg={avg_cost:.3f})",
                hovertemplate="order=%{text}<br>center (j=%{x:.0f}, i=%{y:.0f})<extra></extra>"
            ))

        # 4) start/end markers per block
        start_x, start_y, end_x, end_y = [], [], [], []
        for (bi, bj) in ordered_blocks:
            b = by_ij[(bi, bj)]
            (is_, js_), (ie, je) = b['start'], b['end']
            start_x.append(js_); start_y.append(is_)
            end_x.append(je);   end_y.append(ie)
        if start_x:
            fig.add_trace(go.Scatter(
                x=start_x, y=start_y, mode="markers",
                marker=dict(size=9, symbol="triangle-up", color="green"),
                name="block start (global)"
            ))
            fig.add_trace(go.Scatter(
                x=end_x, y=end_y, mode="markers",
                marker=dict(size=9, symbol="diamond", color="red"),
                name="block end (global)"
            ))

                # 5) stitched path segments (black lines inside each block only)
        if stitched_wp.size:
            seg_x, seg_y = [], []
            for k in range(len(ordered_blocks) - 1):
                b_cur = by_ij[ordered_blocks[k]]
                b_next = by_ij[ordered_blocks[k+1]]

                # current block path
                wp_cur = np.asarray(b_cur['wp_global'])
                if wp_cur.size == 0:
                    continue
                # take only the part from start->end inside this block
                seg_x.extend(wp_cur[:, 1])
                seg_y.extend(wp_cur[:, 0])

                # when switching to next block, break the trace
                fig.add_trace(go.Scatter(
                    x=seg_x, y=seg_y,
                    mode="lines",
                    line=dict(width=3, color="black"),
                    name="Stitched DTW segment" if k == 0 else None,
                    showlegend=(k == 0)
                ))
                seg_x, seg_y = [], []

            # add the last block path too
            b_last = by_ij[ordered_blocks[-1]]
            wp_last = np.asarray(b_last['wp_global'])
            if wp_last.size:
                fig.add_trace(go.Scatter(
                    x=wp_last[:, 1], y=wp_last[:, 0],
                    mode="lines",
                    line=dict(width=3, color="black"),
                    name=None,
                    showlegend=False
                ))


        # 6) optional global DTW overlay
        if isinstance(tiled_result.get('full_global'), dict):
            fg = tiled_result['full_global']
            wp_full = fg.get('wp', None)
            best_full = fg.get('best_cost', None)
            if wp_full is not None:
                wp_tmp = wp_full
                if wp_tmp.ndim == 2 and wp_tmp.shape[0] == 2:
                    wp_tmp = wp_tmp.T
                step = max(1, len(wp_tmp) // 5000)
                fig.add_trace(go.Scatter(
                    x=wp_tmp[::step, 1], y=wp_tmp[::step, 0],
                    mode="lines",
                    line=dict(width=2, color="rgba(0,0,0,0.25)"),
                    name=f"Global DTW" + (f" (cost={best_full:.3f})" if best_full is not None else "")
                ))

        fig.update_layout(
            title="Stage 2: chunk DP with discontinuity penalty + stitched path",
            xaxis_title="F2 frame j (global)",
            yaxis_title="F1 frame i (global)",
            template="plotly_white", width=900, height=700, legend=dict(orientation="h")
        )
        fig.update_xaxes(range=[0, tiled_result['C_shape'][1]], showgrid=False)
        fig.update_yaxes(range=[0, tiled_result['C_shape'][0]], showgrid=False, scaleanchor="x", scaleratio=1)
        fig.show()
   
    

    return {
        'C_chunk': C_chunk,
        'C_global': C_global,
        'L_chunk': L_chunk,
        'D_chunks': D_chunks,
        'L_chunks': L_chunks,
        'B_chunks': B_chunks,
        'ordered_blocks': ordered_blocks,
        'ordered_linear': ordered_linear,
        'avg_cost': float(avg_cost),
        'stitched_wp': stitched_wp,
        'stitched_raw_cost': stitched_raw_cost,
        'stitched_normalized': float(total_normalized_from_stitched) if np.isfinite(total_normalized_from_stitched) else None,
        'n_row': n_row, 'n_col': n_col
    }

In [17]:
## plot_parflex_g: helper to plot
import numpy as np

def plot_parflex_g(
    tiled_result,
    g_result,
    show_all_blocks=False,   # faint overlay of all block paths
    show_full_path=True,     # overlay the global full-matrix path if present
    title="Top-2 DP (chunk lattice) — chosen path"
):
    """
    Visualize the relevant (chosen) path from parflex_g_top2.
    - GOLD: stitched global path (g_result['stitched_wp'])
    - Colored lines: per-block chosen option (p1 or p2) in DP order
    - BLACK & THICK: global full-matrix DTW (if present)
    """
    try:
        import plotly.graph_objects as go
        import plotly.colors as pc
    except Exception as e:
        raise RuntimeError("Plotly is required for plotting.") from e

    blocks = tiled_result['blocks']
    if not blocks:
        raise ValueError("No blocks to plot.")
    by_idx = {(b['bi'], b['bj']): b for b in blocks}

    # palette and figure
    fig = go.Figure()
    palette = pc.qualitative.Plotly

    # 0) draw all block rectangles (grid)
    for b in blocks:
        (r0, r1), (c0, c1) = b['rows'], b['cols']
        fig.add_shape(
            type="rect",
            x0=c0, x1=c1, y0=r0, y1=r1,
            line=dict(color="rgba(140,140,140,0.5)", width=1, dash="dash"),
            fillcolor="rgba(0,0,0,0)",
            layer="below"
        )

    # 1) optional: overlay EVERY block's two local options faintly
    if show_all_blocks:
        for idx, b in enumerate(blocks):
            p1 = b.get('wp_global_p1') if 'wp_global_p1' in b else b.get('p1')
            p2 = b.get('wp_global_p2') if 'wp_global_p2' in b else b.get('p2')
            if p1 is not None and p1.size:
                fig.add_trace(go.Scatter(
                    x=p1[:, 1], y=p1[:, 0],
                    mode="lines",
                    line=dict(width=1, color="rgba(0,0,0,0.25)"),
                    name=f"blk({b['bi']},{b['bj']}) p1 (all)",
                    hoverinfo="skip",
                    showlegend=False
                ))
            if p2 is not None and p2.size:
                fig.add_trace(go.Scatter(
                    x=p2[:, 1], y=p2[:, 0],
                    mode="lines",
                    line=dict(width=1, color="rgba(0,0,0,0.25)"),
                    name=f"blk({b['bi']},{b['bj']}) p2 (all)",
                    hoverinfo="skip",
                    showlegend=False
                ))

    # 2) per-block CHOSEN option in DP order, colored and labeled
    centers_x, centers_y, labels = [], [], []
    for t, (bi, bj, k) in enumerate(g_result['chosen'], start=1):
        b = by_idx[(bi, bj)]
        wp = b.get('wp_global_p1') if k == 0 else b.get('wp_global_p2')
        if wp is None:
            wp = b.get('p1') if k == 0 else b.get('p2')
        if wp is None or not wp.size:
            continue
        color = palette[(t - 1) % len(palette)]
        fig.add_trace(go.Scatter(
            x=wp[:, 1], y=wp[:, 0],
            mode="lines",
            line=dict(width=2.5, color=color),
            name=f"chosen blk({bi},{bj}) p{k+1}",
            hovertemplate=f"blk({bi},{bj}) p{k+1}<extra></extra>",
            opacity=0.9
        ))
        (r0, r1), (c0, c1) = b['rows'], b['cols']
        centers_x.append((c0 + c1) / 2.0)
        centers_y.append((r0 + r1) / 2.0)
        labels.append(str(t))

    # 3) DP order markers at block centers
    if centers_x:
        fig.add_trace(go.Scatter(
            x=centers_x, y=centers_y,
            mode="markers+text",
            marker=dict(size=18, symbol="circle-open-dot"),
            text=labels, textposition="middle center",
            name="chunk order",
            hovertemplate="order=%{text}<br>center (j=%{x:.0f}, i=%{y:.0f})<extra></extra>"
        ))

    # 4) stitched (chosen) path — GOLD/YELLOW
    stitched_wp = g_result.get('stitched_wp', None)
    if stitched_wp is not None and stitched_wp.size:
        fig.add_trace(go.Scatter(
            x=stitched_wp[:, 1], y=stitched_wp[:, 0],
            mode="lines",
            line=dict(width=3.5, color="#FFD700"),  # gold
            name="stitched (chosen)"
        ))

    # 5) global full-matrix path — BLACK & THICK
    if show_full_path and isinstance(tiled_result.get('full_global'), dict):
        wp = tiled_result['full_global'].get('wp', None)
        best = tiled_result['full_global'].get('best_cost', None)
        if wp is not None:
            step = max(1, len(wp) // 5000)
            fig.add_trace(go.Scatter(
                x=wp[::step, 1], y=wp[::step, 0],
                mode="lines",
                line=dict(width=4, color="black"),
                name=f"Global DTW" + (f" (cost={best:.3f})" if best is not None else "")
            ))

    # axes / layout
    L1, L2 = tiled_result['C_shape']
    try:
        gi, gj, gk = g_result['chosen'][-1]
        sc = g_result['score'][gi, gj, gk]
        subtitle = f""
    except Exception:
        subtitle = ""
    fig.update_layout(
        title=title + subtitle,
        xaxis_title="F2 frame j (global)",
        yaxis_title="F1 frame i (global)",
        template="plotly_white",
        width=900, height=700,
        legend=dict(orientation="h")
    )
    fig.update_xaxes(range=[0, L2], showgrid=False)
    fig.update_yaxes(range=[0, L1], showgrid=False, scaleanchor="x", scaleratio=1)
    fig.show()


## Testing

In [18]:
number_tests = 2

In [41]:
## TEST parflex d:
from termcolor import colored
import random

directory = Path("/home/asharma/ttmp/Flex/FlexDTW/Chopin_Mazurkas_features/pre_20/Chopin_Op017No4")
files = list(directory.glob("*.npy"))
for i in range(number_tests):
    f1, f2 = random.sample(files, 2)   # pick 2 different random files
    print(colored(f"Test {i+1}/{number_tests}:", "green"))
    print(colored(f"Running parlfex_e on: {f1.name} vs {f2.name}", "blue"))

    F1 = np.load(f1)
    F2 = np.load(f2)

    C, result = align_system("flexdtw",F1,F2, 'auto')
    alignment = parflex_e(result, C)


[32mTest 1/2:[0m
[34mRunning parlfex_e on: Chopin_Op017No4_Czerny-Stefanska-1949_pid9086-07.npy vs Chopin_Op017No4_Kitain-1937_pid9163-02.npy[0m


[32mTest 2/2:[0m
[34mRunning parlfex_e on: Chopin_Op017No4_Rubinstein-1952_pid9075-13.npy vs Chopin_Op017No4_Smith-1975_pid9054-13.npy[0m


In [20]:
## TEST parflex g:
from termcolor import colored
import random

directory = Path("/home/asharma/ttmp/Flex/FlexDTW/Chopin_Mazurkas_features/matching/Chopin_Op017No4") # CHANGE to the right type of recs
files = list(directory.glob("*.npy"))
for i in range(number_tests):
    f1, f2 = random.sample(files, 2)   # pick 2 different random files
    print(colored(f"Test {i+1}/{number_tests}:", "green"))
    print(colored(f"Running parflex_g on: {f1.name} vs {f2.name}", "blue"))

    F1 = np.load(f1)
    F2 = np.load(f2)

    # build tiles with two local paths per block
    tiled = align_system_multiple_paths("flexdtw", F1, F2, L_block=2000, mask_radius=20, plot_overlay=False) # watch the paramS

    # run N×N×2 DP
    g_res = parflex_g(tiled, allow_diag=True, show_fig=False)

    # plot the relevant (chosen) path
    plot_parflex_g(tiled, g_res, show_all_blocks=True, show_full_path=True)


[32mTest 1/2:[0m
[34mRunning parflex_g on: Chopin_Op017No4_Nadelmann-1956_pid9165-04.npy vs Chopin_Op017No4_Weissenberg-1971_pid9052b-09.npy[0m


NameError: name 'parflex_g' is not defined

## Code to run current benchmarks

In [None]:
## OLD align_system function (for single full-matrix alignment)

def align_system(system, F1, F2, outfile):
    
    subseq = 'subseq' in system
    
    if system == 'flexdtw':
        L1 = F1.shape[1]
        L2 = F2.shape[1]
        buffer = min(L1, L2) * (1 - (1 - other_params[system]['beta']) * min(L1,L2) / max(L1, L2))
        C = 1 - FlexDTW.L2norm(F1).T @ FlexDTW.L2norm(F2) # cos distance metric
        best_cost, wp, debug = FlexDTW.flexdtw(C, steps=steps[system], weights=weights[system], buffer=buffer)
    elif system == 'nwtw':
        downsample = 1
        C = 1 - NWTW.L2norm(F1)[:,0::downsample].T @ NWTW.L2norm(F2)[:,0::downsample] # cos distance metric
        optcost, wp, D, B = NWTW.NWTW_faster(C, gamma=0.346)
    else:
        downsample = 1
        if subseq and (F2.shape[1] < F1.shape[1]):
            C = 1 - DTW.L2norm(F2)[:,0::downsample].T @ DTW.L2norm(F1)[:,0::downsample] # cos distance metric
            wp = DTW.alignDTW(C, steps=steps[system], weights=weights[system], downsample=downsample, outfile=outfile, subseq=subseq)
            wp = wp[::-1,:]
        else:
            C = 1 - DTW.L2norm(F1)[:,0::downsample].T @ DTW.L2norm(F2)[:,0::downsample] # cos distance metric
            wp = DTW.alignDTW(C, steps=steps[system], weights=weights[system], downsample=downsample, outfile=outfile, subseq=subseq)
            
    if wp is not None:
        pickle.dump(wp, open(outfile, 'wb'))

In [None]:
def run_all_benchmarks(outdir):
    parts_batch = []
    queryids = []
    with open(QUERY_LIST, 'r') as f:
        for line in f:
            parts = line.strip().split(' ')
            assert len(parts) == 2
            queryid = os.path.basename(parts[0]) + '__' + os.path.basename(parts[1])
            
            if 'Czerny-Stefanska-1949_pid9086' in queryid:
                continue
            
            parts_batch.append(parts)
            queryids.append(queryid)
            
    for benchmark in tqdm(BENCHMARKS):
#         for i in range(len(parts_batch)):
#             run_benchmark(benchmark, FEAT_DIRS[benchmark][0], FEAT_DIRS[benchmark][1], parts_batch[i], outdir, queryids[i])
        run_benchmark_batch(benchmark, FEAT_DIRS[benchmark][0], FEAT_DIRS[benchmark][1], parts_batch, outdir, queryids, n_cores=4)

In [None]:
def run_benchmark_batch(benchmark, featdir1, featdir2, parts_batch, outdir, queryids, n_cores):
    inputs = []
    assert len(parts_batch) == len(queryids)
    count = 30
    for i in range(len(parts_batch)):
        featfile1 = (featdir1 / parts_batch[i][0]).with_suffix('.npy')
        featfile2 = (featdir2 / parts_batch[i][1]).with_suffix('.npy')
        
        F1 = np.load(featfile1)
        F2 = np.load(featfile2)

        for system in SYSTEMS:
            # only compute alignment if this hypothesis file doesn't already exist
            outfile = get_outfile(outdir, benchmark, system, queryids[i]) 
            if not os.path.isfile(outfile):  
                inputs.append((system, F1, F2, outfile))
                if count%30 ==0: 
                    print("Aligning",count, len(parts_batch), outfile)

    # process files in parallel
    pool = multiprocessing.Pool(processes = multiprocessing.cpu_count()-1)
    pool.starmap(align_system, inputs)
    count+=1
    
    
    return

In [None]:
def run_benchmark(benchmark, featdir1, featdir2, parts, outdir, queryid):
    featfile1 = (featdir1 / parts[0]).with_suffix('.npy')
    featfile2 = (featdir2 / parts[1]).with_suffix('.npy')

    F1 = np.load(featfile1)
    F2 = np.load(featfile2)
        
    # run all baselines
    count = 20
    for system in SYSTEMS:
        
        # only compute alignment if this hypothesis file doesn't already exist
        outfile = get_outfile(outdir, benchmark, system, queryids[i])
        print(outfile, "not computed yet")
        if not os.path.isfile(outfile): 
            print(system, featfile1, featfile2, outfile)
            if count %20==0:  
                print("Aligning",count, outfile)
            align_system(system, F1, F2, outfile)
        count+=1

In [None]:
# outdir = Path(f'experiments_{DATASET}/{VERSION}')
# run_all_benchmarks(outdir)