In [57]:
import logging
import os
import sys
import time
import traceback
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
import seaborn as sns
from tqdm import tqdm

# Make the local `spectrum` package importable BEFORE any `spectrum.*` import.
sys.path.append(os.path.abspath("../.."))
from spectrum.config import WINDOW_SIZE
from spectrum.models import LSTM
from spectrum.utils import set_random_state

# Hide library warnings and emit INFO-level log records.
warnings.filterwarnings("ignore")
logging.basicConfig(level="INFO")

# Seaborn theme first, then matplotlib overrides: set_theme writes many of the
# same rcParams keys, so the overrides have to be applied after it.
sns.set_theme(style="whitegrid")
_plot_rc = {
    # axes frame
    "axes.edgecolor": "0.3",
    "axes.linewidth": 0.8,
    # typography
    "font.size": 12,
    "axes.titlesize": 14,
    "axes.labelsize": 12,
    "axes.titleweight": "bold",
    # legend
    "legend.fontsize": 10,
    "legend.frameon": False,
    # output resolution
    "figure.dpi": 120,
}
plt.rcParams.update(_plot_rc)

# Project helper; presumably seeds the RNGs used for training — see spectrum.utils.
set_random_state()

In [58]:
# File stems of the Yahoo A1 series evaluated in this run (read by process_single_id).
selected_ids = [4, 17, 33]

# Per-series prediction CSVs and the summary table are written here.
results_dir = "../../results/models/lstm"
os.makedirs(name=results_dir, exist_ok=True)


def find_best_threshold(scores, true_labels, thresholds=None):
    """Sweep candidate thresholds and return the one that maximises F1.

    A point is predicted anomalous when its score is strictly greater than the
    threshold. Labels equal to 1 count as positives and labels equal to 0 as
    negatives; any other label value is ignored by the confusion counts.

    NOTE(review): the threshold is chosen on the same labels the metrics are
    reported on, so when ``true_labels`` are evaluation labels the returned
    metrics are an oracle (best-case) estimate, not a held-out one.

    Args:
        scores: 1-D array-like of anomaly scores; higher means more anomalous.
        true_labels: array-like of ground-truth labels aligned with ``scores``.
        thresholds: optional iterable of candidate thresholds. When omitted,
            the 50th-99th integer percentiles of ``scores`` plus the 99.5th
            and 99.9th percentiles are tried.

    Returns:
        ``(best_threshold, best_metrics)``. ``best_metrics`` has the keys
        threshold, accuracy, precision, recall, f1, fnr, fpr, tp, fp, tn, fn
        for the winning threshold. Ties keep the earliest candidate. If no
        candidate reaches F1 > 0 (e.g. there are no positive labels), the
        first candidate and its metrics are returned instead of an empty dict.

    Raises:
        ValueError: if ``thresholds`` is empty, or if ``scores`` and
            ``true_labels`` have different lengths.
    """
    scores = np.asarray(scores).ravel()
    true_labels = np.asarray(true_labels).ravel()
    if len(scores) != len(true_labels):
        raise ValueError(
            f"scores and true_labels differ in length: {len(scores)} vs {len(true_labels)}"
        )

    if thresholds is None:
        # percentile grid 50..99 plus two extreme-tail points, in one call
        percentiles = list(range(50, 100)) + [99.5, 99.9]
        thresholds = list(np.percentile(scores, percentiles))
    else:
        thresholds = list(thresholds)
    if not thresholds:
        raise ValueError("at least one candidate threshold is required")

    # Label masks do not depend on the threshold, so build them once.
    is_pos = true_labels == 1
    is_neg = true_labels == 0

    # Start below any attainable F1 so the first candidate is always recorded;
    # with 0 here an all-zero-F1 sweep would return an empty metrics dict.
    best_f1 = -1.0
    best_threshold = thresholds[0]
    best_metrics = {}

    for threshold in thresholds:
        pred_pos = scores > threshold

        # confusion matrix
        tp = int(np.sum(is_pos & pred_pos))
        fp = int(np.sum(is_neg & pred_pos))
        tn = int(np.sum(is_neg & ~pred_pos))
        fn = int(np.sum(is_pos & ~pred_pos))

        # metrics; each ratio falls back to 0 when its denominator is 0
        total = tp + fp + tn + fn
        accuracy = (tp + tn) / total if total > 0 else 0
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        fnr = fn / (fn + tp) if (fn + tp) > 0 else 0
        fpr = fp / (fp + tn) if (fp + tn) > 0 else 0

        if f1 > best_f1:
            best_f1 = f1
            best_threshold = threshold
            best_metrics = {
                'threshold': threshold,
                'accuracy': accuracy,
                'precision': precision,
                'recall': recall,
                'f1': f1,
                'fnr': fnr,
                'fpr': fpr,
                'tp': tp,
                'fp': fp,
                'tn': tn,
                'fn': fn
            }

    return best_threshold, best_metrics


def process_single_id(data_id):
    """Train, score, threshold and persist results for one Yahoo A1 series.

    Reads the train split, test split and complete series for ``data_id``,
    fits an ``LSTM`` on the train values, scores the test values, picks the
    F1-maximising threshold with ``find_best_threshold`` against the labels of
    the complete series, and writes a per-point CSV with columns
    timestamp, value, label, predicted, anomaly_score to
    ``results_dir/{data_id}.csv``. Positions without a score (the training
    part and the first ``WINDOW_SIZE`` test points) keep predicted = 0 and
    anomaly_score = 0.

    If the number of scores and the number of aligned labels differ, both are
    truncated to the shorter length and a warning is logged.

    Args:
        data_id: file stem of the series, e.g. ``4`` for ``4.csv``.

    Returns:
        dict with id, training/testing/total elapsed seconds, train/test
        sample counts, best_threshold and every entry of the metrics dict
        returned by ``find_best_threshold``.
    """
    print(f"processing: {data_id}")

    train_data = pl.read_csv(f"../../datasets/Yahoo/train/A1/{data_id}.csv")
    test_data = pl.read_csv(f"../../datasets/Yahoo/test/A1/{data_id}.csv")
    complete_data = pl.read_csv(f"../../datasets/Yahoo/data/A1Benchmark/{data_id}.csv")

    print("  training model...")
    model = LSTM()
    start_time = time.perf_counter()
    model.fit(train_data["value"])
    training_time = time.perf_counter() - start_time

    print("  anomaly detection...")
    start_time = time.perf_counter()
    scores = model.predict(test_data["value"])
    scoring_time = time.perf_counter() - start_time

    train_len = len(train_data)
    # Assumes the complete series is the train split followed by the test split
    # and that the model emits its first score WINDOW_SIZE points into the test
    # split — TODO confirm against spectrum.models.LSTM.
    offset = train_len + WINDOW_SIZE
    scores_array = np.asarray(scores).ravel()
    test_true_labels = complete_data["label"][offset:].to_numpy()

    # Align scores with labels; a mismatch would otherwise break both the
    # threshold search and the write-back into the full-length arrays below.
    min_len = min(len(scores_array), len(test_true_labels))
    if len(scores_array) != len(test_true_labels):
        logging.warning(
            "ID %s: %d scores vs %d labels after offset %d; truncating to %d",
            data_id, len(scores_array), len(test_true_labels), offset, min_len,
        )
    scores_array = scores_array[:min_len]
    test_true_labels = test_true_labels[:min_len]

    print("  finding best threshold...")
    best_threshold, best_metrics = find_best_threshold(scores_array, test_true_labels)

    complete_values = complete_data["value"].to_numpy()
    complete_labels = complete_data["label"].to_numpy()
    if "timestamp" in complete_data.columns:
        complete_timestamps = complete_data["timestamp"].to_numpy()
    else:
        complete_timestamps = np.arange(len(complete_values))

    # Only the scored span is filled; everything else stays 0.
    scored_span = slice(offset, offset + min_len)
    complete_predictions = np.zeros(len(complete_values))
    complete_anomaly_scores = np.zeros(len(complete_values))
    complete_predictions[scored_span] = (scores_array > best_threshold).astype(int)
    complete_anomaly_scores[scored_span] = scores_array

    result_df = pd.DataFrame({
        'timestamp': complete_timestamps,
        'value': complete_values,
        'label': complete_labels,
        'predicted': complete_predictions,
        'anomaly_score': complete_anomaly_scores
    })

    output_file = os.path.join(results_dir, f"{data_id}.csv")
    result_df.to_csv(output_file, index=False)
    print(f"  results saved to: {output_file}")

    return {
        'id': data_id,
        'training_time': training_time,
        'testing_time': scoring_time,
        'total_time': training_time + scoring_time,
        'train_samples': train_len,
        'test_samples': min_len,
        'best_threshold': best_threshold,
        **best_metrics
    }


all_results = []
failed_ids = []
print(f"processing {len(selected_ids)} datasets...")

for data_id in tqdm(selected_ids, desc="processing"):
    try:
        result = process_single_id(data_id)

        print(f"  ID {data_id} completed:")
        print(f"    best_threshold: {result['best_threshold']:.4f}")
        print(f"    f1: {result['f1']:.4f}")
        print(f"    precision: {result['precision']:.4f}")
        print(f"    recall: {result['recall']:.4f}")
        print(f"    accuracy: {result['accuracy']:.4f}")
    except Exception as e:
        # Best-effort sweep: report the failure with its traceback, remember the
        # ID, and continue with the next series instead of aborting the cell.
        print(f"  processing {data_id} failed: {e}")
        traceback.print_exc()
        failed_ids.append(data_id)
        continue

    # Only results whose headline metrics could be reported enter the summary,
    # so one incomplete result cannot break the aggregate table below.
    all_results.append(result)

if all_results:
    summary_df = pd.DataFrame(all_results)
    summary_file = os.path.join(results_dir, "lstm.csv")
    summary_df.to_csv(summary_file, index=False)
    print(f"summary results saved to: {summary_file}")

    print("\n" + "=" * 80)
    print("LSTM anomaly detection results")
    print("=" * 80)
    print(f"processed {len(all_results)} datasets")
    # pandas std() uses ddof=1, so the spread is NaN when only one dataset succeeded.
    for metric_label, metric_col in (
        ("F1", "f1"),
        ("precision", "precision"),
        ("recall", "recall"),
        ("accuracy", "accuracy"),
    ):
        print(
            f"average {metric_label}: {summary_df[metric_col].mean():.4f} ± {summary_df[metric_col].std():.4f}"
        )
    print(f"average training time: {summary_df['training_time'].mean():.2f}s")
    print(f"average scoring time: {summary_df['testing_time'].mean():.2f}s")
    print("=" * 80)

    print("details:")
    display_cols = ['id', 'f1', 'precision', 'recall', 'accuracy', 'best_threshold', 'tp', 'fp', 'tn', 'fn']
    print(summary_df[display_cols].round(4))
else:
    print("no results")

if failed_ids:
    print(f"failed IDs ({len(failed_ids)}): {failed_ids}")

processing 3 datasets...


processing:   0%|          | 0/3 [00:00<?, ?it/s]

processing: 4
  training model...


processing:  33%|███▎      | 1/3 [00:06<00:12,  6.02s/it]

  anomaly detection...
  finding best threshold...
  results saved to: ../../results/models/lstm/4.csv
  ID 4 completed:
    best_threshold: 190.0937
    f1: 0.7826
    precision: 0.6429
    recall: 1.0000
    accuracy: 0.9926
processing: 17
  training model...


processing:  67%|██████▋   | 2/3 [00:07<00:03,  3.07s/it]

  anomaly detection...
  finding best threshold...
  results saved to: ../../results/models/lstm/17.csv
  ID 17 completed:
    best_threshold: 91.3125
    f1: 0.9896
    precision: 1.0000
    recall: 0.9795
    accuracy: 0.9956
processing: 33
  training model...


processing: 100%|██████████| 3/3 [00:11<00:00,  3.75s/it]

  anomaly detection...
  finding best threshold...
  results saved to: ../../results/models/lstm/33.csv
  ID 33 completed:
    best_threshold: 4842.6587
    f1: 1.0000
    precision: 1.0000
    recall: 1.0000
    accuracy: 1.0000
summary results saved to: ../../results/models/lstm/lstm.csv

LSTM anomaly detection results
processed 3 datasets
average F1: 0.9241 ± 0.1226
average precision: 0.8810 ± 0.2062
average recall: 0.9932 ± 0.0119
average accuracy: 0.9961 ± 0.0037
average training time: 3.69s
average scoring time: 0.04s
details:
   id      f1  precision  recall  accuracy  best_threshold   tp  fp   tn  fn
0   4  0.7826     0.6429  1.0000    0.9926      190.093704    9   5  666   0
1  17  0.9896     1.0000  0.9795    0.9956       91.312500  143   0  534   3
2  33  1.0000     1.0000  1.0000    1.0000     4842.658691    1   0  687   0



