In [None]:
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from tslearn.utils import to_time_series_dataset
import wandb
from datetime import timedelta
import matplotlib.pyplot as plt
import joblib
from tslearn.neighbors import KNeighborsTimeSeries, KNeighborsTimeSeriesClassifier
from tslearn.barycenters import dtw_barycenter_averaging
from tslearn.metrics import dtw
from scripts.NativeGuide.Native_Guide_DBA_lib_1_1 import *

In [None]:
# Input locations and the columns retained from every multivariate series.
label_file = '../../data/raw/class_labels.csv'
data_dir = '../../data/raw/data/'
cols_to_keep = ['p3_flux_ic', 'p5_flux_ic', 'p7_flux_ic', 'long']

# Load the labels, then the series themselves; one file is explicitly excluded.
# NOTE(review): the helpers below come from the star-import of
# Native_Guide_DBA_lib_1_1 — their exact return types are not visible here.
label_map = load_labels(label_file)
X, y, meta = load_mvts_and_labels(
    data_dir,
    label_map,
    cols_to_keep=cols_to_keep,
    exclude_files=['2005-09-07_14-25.csv'],
)

# Split into train / validation / test partitions (data, labels and metadata).
(X_train, X_val, X_test,
 y_train, y_val, y_test,
 meta_train, meta_val, meta_test) = train_val_test_split(X, y, meta)

# Convert each partition with the same column selection, in train/val/test order.
X_train_3D, X_val_3D, X_test_3D = (
    convert_to_3d_numpy(partition, cols_to_keep)
    for partition in (X_train, X_val, X_test)
)

# joblib.load unpickles the file, so only load model files from a trusted source.
model = joblib.load('../../models/KNN_1_TS__classifier_v1.3.pkl')

In [None]:
def _mask_window(query, start, end, mask_strategy):
    """Return a copy of ``query`` whose columns ``start:end`` are replaced per ``mask_strategy``."""
    perturbed = query.copy()
    window = query[:, start:end]

    if mask_strategy == "mean":
        # Flatten every row to its own mean over the window.
        perturbed[:, start:end] = np.mean(window, axis=1, keepdims=True)
    elif mask_strategy == "zero":
        perturbed[:, start:end] = 0.0
    elif mask_strategy == "noise":
        # A private RandomState seeded with 0 yields the same draws on every
        # call (reproducible) without reseeding the process-wide NumPy RNG,
        # which would silently override any seed the caller has set.
        rng = np.random.RandomState(0)
        perturbed[:, start:end] = rng.normal(
            loc=0.0, scale=np.std(window), size=window.shape
        )
    else:
        raise ValueError(f"Unknown mask_strategy: {mask_strategy}")

    return perturbed


def compute_importance(query, model, window_sizes=[6, 30, 60, 120, 180], mask_strategy="mean"):
    """
    Compute importance scores for subsequences of a multivariate time series
    by perturbing windows and measuring prediction probability drop.

    Parameters
    ----------
    query : 2-D array
        A single series, indexed here as ``(n_channels, n_timestamps)``: windows
        slide along axis 1 and the "mean" fill is computed per row.
        NOTE(review): tslearn's native layout is (n_timestamps, n_features) —
        confirm the arrays passed in really are channel-first.
        Integer/boolean input is promoted to float so that mean/noise fills are
        not truncated when written into the copy.
    model : object
        Anything exposing ``predict_proba(batch)``; it is called with a batch of
        one series (``query[np.newaxis]``) and the first row of the result is used.
    window_sizes : iterable of int
        Window widths ``W`` to evaluate. Each is slid with stride
        ``max(1, W // 2)``; a width larger than the series yields an empty list.
    mask_strategy : {"mean", "zero", "noise"}
        How the window is overwritten: per-row window mean, zeros, or zero-mean
        Gaussian noise whose scale is the standard deviation of the whole window.

    Returns
    -------
    dict
        ``{W: [(start, end, score), ...]}`` where ``score`` is the probability of
        the originally predicted class minus its probability after masking
        ``[start, end)``. Larger positive values mean a larger drop.

    Raises
    ------
    ValueError
        If ``mask_strategy`` is not one of the supported names, or ``query`` is
        not 2-D.
    """
    # Fail fast on a bad strategy instead of after a (possibly costly) prediction.
    if mask_strategy not in ("mean", "zero", "noise"):
        raise ValueError(f"Unknown mask_strategy: {mask_strategy}")

    query = np.asarray(query)
    if query.dtype.kind in "iub":
        # Writing float fill values into an integer copy would truncate them.
        query = query.astype(float)

    _, n_timestamps = query.shape  # unpacking also enforces a 2-D input

    orig_probs = model.predict_proba(query[np.newaxis, :, :])[0]
    orig_class = np.argmax(orig_probs)
    orig_prob = orig_probs[orig_class]

    importance_dict = {}

    for W in window_sizes:
        window_scores = []
        # Slide window across with stride = W/2 (overlap)
        for start in range(0, n_timestamps - W + 1, max(1, W // 2)):
            end = start + W
            perturbed = _mask_window(query, start, end, mask_strategy)

            new_prob = model.predict_proba(perturbed[np.newaxis, :, :])[0][orig_class]
            # Drop in probability of the original class -> importance.
            window_scores.append((start, end, orig_prob - new_prob))

        importance_dict[W] = window_scores

    return importance_dict


def grid_search_importance(X_test_3D, model, query_indices,
                           window_sizes=[6,30,60,120,180],
                           mask_strategies=["mean","zero","noise"]):
    """
    Evaluate every (query, mask strategy, window size) combination.

    For each index in ``query_indices`` the series ``X_test_3D[index]`` is passed
    to ``compute_importance`` once per mask strategy, and every scored window is
    flattened into one record.

    Returns a DataFrame with one row per scored window and the columns
    ``query_id``, ``mask_strategy``, ``window_size``, ``start``, ``end`` and
    ``importance_score``.
    """
    rows = []

    progress = tqdm(query_indices, desc="Grid Search Queries")
    for qid in progress:
        series = X_test_3D[qid]

        for mask in mask_strategies:
            scores_by_width = compute_importance(series, model, window_sizes, mask)

            # One record per (window size, window position).
            rows.extend(
                {
                    "query_id": qid,
                    "mask_strategy": mask,
                    "window_size": width,
                    "start": lo,
                    "end": hi,
                    "importance_score": drop,
                }
                for width, scored_windows in scores_by_width.items()
                for lo, hi, drop in scored_windows
            )

    return pd.DataFrame(rows)


In [None]:
# Pick up to 100 random queries from the test set. Sampling without replacement
# raises ValueError when more items are requested than exist, so the sample size
# is capped at the number of available test series.
np.random.seed(42)
n_queries = min(100, len(X_test_3D))
query_indices = np.random.choice(len(X_test_3D), size=n_queries, replace=False)

# Score every sampled query for each window size and masking strategy.
df_importance = grid_search_importance(
    X_test_3D, model, query_indices,
    window_sizes=[6,30,60,120,180],
    mask_strategies=["mean","zero","noise"]
)

df_importance.head()
