# Sequence Anomaly Detection (LSTM/GRU)
Trains LSTM and GRU classifiers on synthetic, behavior-like session sequences of variable length, using padding and masking. Runs quickly on CPU and needs no external data.

In [None]:
from pathlib import Path
import sys

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

def find_project_root(start: Path, package: str = 'uais') -> Path:
    """Return the nearest directory at or above *start* that contains ``src/<package>``.

    The notebook originally hard-coded ``Path('..')``, which only works when the
    kernel's working directory is exactly one level below the project root.
    Searching upward also supports launching from the root itself or from a
    deeper subdirectory.

    Args:
        start: Directory to begin the search from (checked first, then its parents).
        package: Name of the package expected under ``src/``.

    Returns:
        The first matching directory, or ``start.parent`` when none matches
        (the same result the original ``Path('..')`` lookup produced).
    """
    for candidate in (start, *start.parents):
        if (candidate / 'src' / package).is_dir():
            return candidate
    return start.parent


project_root = find_project_root(Path.cwd()).resolve()
src_path = project_root / 'src'
# Prepend rather than append so the working-tree sources take precedence over
# any (possibly stale) copy of the package installed in site-packages.
if str(src_path) not in sys.path:
    sys.path.insert(0, str(src_path))

from uais.sequence.build_sequences import build_sequences, pad_sequences
from uais.sequence.train_lstm import train_lstm_classifier, predict_lstm
from uais.sequence.train_gru import train_gru_classifier, predict_gru
from uais.sequence.evaluate_sequence import evaluate_sequence_predictions
from uais.explainability.sequence_explainer import sequence_saliency

import torch

# Fix the global NumPy and torch RNGs so notebook runs are repeatable.
SEED = 42
np.random.seed(SEED)
_ = torch.manual_seed(SEED)  # returns a Generator; bind it to keep output quiet

# Echo the environment so a saved notebook records where and how it ran.
for caption, value in (
    ('Project root:', project_root),
    ('Using torch', torch.__version__),
):
    print(caption, value)


In [None]:
# Generate synthetic CERT-like session sequences (variable length).
# Every random draw happens in the same order as a straightforward nested loop,
# so the seeded generator below always yields the same dataset.


def _session_events(gen, session_idx):
    """Yield one event record per time step for a single synthetic session.

    About 20% of sessions are flagged anomalous. Their steps draw event codes
    from {2, 3}, carry an extra ~N(150, 40) bytes_out, and fail logins with
    probability 0.25 instead of 0.06.
    """
    n_steps = int(gen.integers(8, 30))
    day_offset = int(gen.integers(0, 21))
    start = np.datetime64('2024-01-01') + np.timedelta64(day_offset, 'D')
    anomalous = bool(gen.random() < 0.2)

    for step_idx in range(n_steps):
        # Roughly 5 minutes apart with 0-3 minutes of jitter, so timestamps stay increasing.
        minute = int(step_idx * 5 + gen.integers(0, 4))
        code = int(gen.integers(0, 3))
        sent = float(gen.normal(120, 25))
        fail = int(gen.binomial(1, 0.06))
        if anomalous:
            # Baseline draws are still consumed above; anomalous values replace them.
            code = int(gen.integers(2, 4))
            sent += float(gen.normal(150, 40))
            fail = int(gen.binomial(1, 0.25))
        yield {
            'session_id': f's{session_idx}',
            'timestamp': pd.Timestamp(start + np.timedelta64(minute, 'm')),
            'event_code': code,
            'bytes_out': max(sent, 0.0),
            'failed_login': fail,
            'label': int(anomalous),
        }


rng = np.random.default_rng(7)
num_sessions = 320
records = [
    event
    for idx in range(num_sessions)
    for event in _session_events(rng, idx)
]

seq_df = (
    pd.DataFrame(records)
    .sort_values(['session_id', 'timestamp'])
    .reset_index(drop=True)
)
print(seq_df.head())
print('Events:', len(seq_df), 'Sequences:', seq_df['session_id'].nunique())
n_positive = seq_df.groupby('session_id')['label'].max().sum()
print('Positive sequences:', n_positive)


In [None]:
# Build padded tensors + mask for the sequence models.
# build_sequences groups seq_df by session and orders each group by timestamp
# (per its keyword arguments). pad_sequences then brings every sequence to a
# common length and returns a mask marking the real steps.
# NOTE(review): sessions above have at most 29 events (rng.integers(8, 30)), so
# max_len=40 presumably never truncates. Confirm pad_sequences' truncation rule
# if the generator changes.
PAD_LEN = 40
sequences, labels = build_sequences(
    seq_df,
    id_column='session_id',
    time_column='timestamp',
    target_column='label',
)
padded, mask = pad_sequences(sequences, max_len=PAD_LEN)
labels = np.asarray(labels)

for caption, value in (
    ('Padded shape:', padded.shape),
    ('Mask shape:', mask.shape),
    ('Feature dim:', padded.shape[-1]),
    ('Positive ratio:', labels.mean()),
):
    print(caption, value)


In [None]:
# Train/test split and LSTM classifier.
# Stratified 75/25 split. train_test_split applies one shuffled row order to
# padded, mask and labels, so the three stay aligned.
X_train, X_test, mask_train, mask_test, y_train, y_test = train_test_split(
    padded,
    mask,
    labels,
    test_size=0.25,
    stratify=labels,
    random_state=42,
)

# Also reused by the GRU cell so both models get identical hyperparameters.
config = {'sequence': {'hidden_dim': 32, 'batch_size': 32, 'epochs': 6, 'lr': 1e-3}}

# Re-seed immediately before training. The global RNGs were seeded only once at
# the top of the notebook. Without this, weight initialisation (and any
# shuffling the trainer does with these RNGs, presumably) depends on how much
# RNG state earlier cells or previous re-runs consumed, so re-running this
# cell gives different numbers.
np.random.seed(42)
_ = torch.manual_seed(42)

lstm_model, lstm_loss = train_lstm_classifier(X_train, mask_train, y_train, config)
lstm_scores = predict_lstm(lstm_model, X_test, mask_test)
lstm_metrics = evaluate_sequence_predictions(y_test, lstm_scores)

# float() accepts Python floats, NumPy scalars and 0-d tensors alike.
print('LSTM train loss:', round(float(lstm_loss), 4))
print('LSTM metrics:')
for k, v in lstm_metrics.items():
    print(f"  {k}: {v:.4f}")


In [None]:
# Lightweight GRU comparison (same hyperparams, same split as the LSTM cell).
# Re-seed to the notebook seed (42) right before training. Otherwise the GRU's
# starting RNG state depends on how much randomness the LSTM training (or
# earlier re-runs) consumed, so the comparison would not be controlled or
# reproducible.
np.random.seed(42)
_ = torch.manual_seed(42)

gru_model, gru_loss = train_gru_classifier(X_train, mask_train, y_train, config)
gru_scores = predict_gru(gru_model, X_test, mask_test)
gru_metrics = evaluate_sequence_predictions(y_test, gru_scores)

# float() accepts Python floats, NumPy scalars and 0-d tensors alike.
print('GRU train loss:', round(float(gru_loss), 4))
print('GRU metrics:')
for k, v in gru_metrics.items():
    print(f"  {k}: {v:.4f}")


In [None]:
# Simple per-time-step saliency for one held-out sequence.
# NOTE(review): no trained model is passed to sequence_saliency, so these
# scores presumably describe the input itself ("avg magnitude") rather than the
# LSTM's sensitivity. Confirm against uais.explainability.sequence_explainer.
example_idx = 0
single = slice(example_idx, example_idx + 1)  # keep the leading batch axis
saliency = sequence_saliency(X_test[single], mask_test[single])

example_score = float(lstm_scores[example_idx])
print('Example score:', round(example_score, 4))
print('Saliency (step -> avg magnitude):')
# Report only the steps the mask marks as real (non-padding).
for t, value in saliency.items():
    if not mask_test[example_idx, t] > 0:
        continue
    print(f"  t={t}: {value:.4f}")
