In [5]:
# ======================= LSTM PAIRWISE AGENTS + MQTT LIVE =======================
import os
import json
import time
from collections import deque
from typing import List, Dict, Any, Optional, Tuple

import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report

import tensorflow as tf
from tensorflow.keras import Sequential, layers
from tensorflow.keras.layers import LSTM, Dropout

import paho.mqtt.client as mqtt
from paho.mqtt.client import CallbackAPIVersion

# ---------------------------------------------------------------------
# PATHS
# ---------------------------------------------------------------------
DATA_DIR = r"C:\Users\Muckbul\Desktop\agenticAI"
ARTIFACTS_ROOT = DATA_DIR  # where models + scaler + feature_names go

# ---------------------------------------------------------------------
# MQTT CONFIG
# ---------------------------------------------------------------------
MQTT_BROKER = "localhost"
MQTT_PORT = 1883

TOPIC_BASE = "fault_detection/report/pairwise"
LIVE_FEATURE_TOPIC = "fault_detection/live/features"
LIVE_RESULT_TOPIC = "fault_detection/live/result"

# ---------------------------------------------------------------------
# TRAINING / MODEL CONFIG
# ---------------------------------------------------------------------
TIME_STEPS = 10
SEED, TEST_SIZE, VAL_SIZE = 42, 0.2, 0.2
BATCH, EPOCHS, LR = 128, 20, 1e-4

RUN_TRAINING = True   # set to False when models are already trained

# use only first N rows from each OM file for training
MAX_ROWS_PER_FILE = 500

# if you want to only use data after a certain time, set a float here;
# otherwise keep as None (your OM1..OM8 already contain faults from start).
FAULT_START_TIME: Optional[float] = None  # e.g. 500.0 if needed

rng = np.random.RandomState(SEED)
tf.random.set_seed(SEED)

# CLASS NAMES (OM0..OM8)
CLASS_NAMES = [f"OM{i}" for i in range(0, 9)]  # 9 classes total
# All unordered pairs (i < j)
PAIR_LIST: List[Tuple[str, str]] = [
    (CLASS_NAMES[i], CLASS_NAMES[j])
    for i in range(len(CLASS_NAMES))
    for j in range(i + 1, len(CLASS_NAMES))
]


# ---------------------------------------------------------------------
# BASIC UTILITIES
# ---------------------------------------------------------------------
def load_numeric(
    path: str,
    max_rows: Optional[int] = None,
    fault_start_time: Optional[float] = None,
) -> pd.DataFrame:
    """
    Load numeric columns from Excel/CSV.
    If 'Time' column exists and fault_start_time is not None, keep only
    rows with Time >= fault_start_time. Then keep only numeric columns.
    Optionally limit to 'max_rows' rows (after filtering).
    """
    ext = os.path.splitext(path)[1].lower()
    if ext in (".xlsx", ".xls"):
        df = pd.read_excel(path)
    elif ext == ".csv":
        df = pd.read_csv(path)
    else:
        raise ValueError(f"Unsupported file type: {ext}")

    # Force-convert Time to numeric (if present)
    if "Time" in df.columns:
        df["Time"] = pd.to_numeric(df["Time"], errors="coerce")
        df = df.dropna(subset=["Time"])
        if fault_start_time is not None:
            df = df[df["Time"] >= fault_start_time].copy()

    # Numeric columns only
    num = df.select_dtypes(include=[np.number]).copy()
    num = num.replace([np.inf, -np.inf], np.nan).ffill().bfill()

    if max_rows is not None and len(num) > max_rows:
        num = num.iloc[:max_rows].copy()

    return num


def find_file(base: str) -> Optional[str]:
    p_xlsx = os.path.join(DATA_DIR, f"{base}.xlsx")
    p_csv = os.path.join(DATA_DIR, f"{base}.csv")
    if os.path.exists(p_xlsx):
        return p_xlsx
    if os.path.exists(p_csv):
        return p_csv
    return None


def create_sequences(X: np.ndarray, y: np.ndarray, time_steps: int = 10, step: int = 1):
    Xs, ys = [], []
    for i in range(0, len(X) - time_steps, step):
        Xs.append(X[i: i + time_steps])
        ys.append(y[i + time_steps - 1])
    return np.array(Xs), np.array(ys)


def oversample_indices(idx_array: np.ndarray, n_target: int) -> np.ndarray:
    if len(idx_array) == 0:
        return idx_array
    if n_target <= len(idx_array):
        return rng.choice(idx_array, size=n_target, replace=False)
    extra = rng.choice(idx_array, size=n_target - len(idx_array), replace=True)
    return np.concatenate([idx_array, extra], axis=0)


def make_topic(exp_name: str) -> str:
    safe = exp_name.replace(" ", "_").replace("(", "").replace(")", "")
    return f"{TOPIC_BASE}/{safe}"


# ---------------------------------------------------------------------
# MQTT CALLBACKS (for training summary only)
# ---------------------------------------------------------------------
def on_connect(client, userdata, flags, reason_code, properties=None):
    print(f"[MQTT] Connected rc={reason_code}")


def on_disconnect(client, userdata, flags, reason_code, properties=None):
    print(f"[MQTT] Disconnected rc={reason_code}")
    # For paho v2, reason_code may be an object; treat non-zero as error
    try:
        rc_val = int(reason_code)
    except Exception:
        rc_val = getattr(reason_code, "value", 0)

    if rc_val != 0:
        print("[MQTT] Trying to reconnect...")
        try:
            client.reconnect()
        except Exception as e:
            print(f"[MQTT] Reconnect failed: {e}")



def on_message(client, userdata, msg):
    # Not used here; keep for completeness.
    pass


def publish_mqtt_report(client, exp_name: str, accuracy: float, report_dict: dict):
    topic = make_topic(exp_name)
    payload = {
        "agent": exp_name,
        "model_type": "LSTM_pairwise",
        "overall_test_accuracy": round(float(accuracy), 4),
        "weighted_f1": round(float(report_dict["weighted avg"]["f1-score"]), 4),
        "timestamp": pd.Timestamp.now().isoformat(),
    }
    msg = json.dumps(payload, ensure_ascii=False)
    print(f"[MQTT] Publish summary -> {topic}")
    if client:
        client.publish(topic, msg, qos=1)


# ---------------------------------------------------------------------
# GLOBAL SCALER (shared by all pairwise agents)
# ---------------------------------------------------------------------
def build_global_scaler() -> Tuple[StandardScaler, List[str]]:
    """
    Load all OM0..OM8 data (limited to MAX_ROWS_PER_FILE) and fit one
    StandardScaler on the concatenated dataset. Save mean/scale and
    feature_names to ARTIFACTS_ROOT.
    """
    X_all_parts = []
    feature_names: Optional[List[str]] = None

    for cls in CLASS_NAMES:
        p = find_file(cls)
        if p is None:
            print(f"[GLOBAL SCALER] WARNING: file for {cls} not found; skipping.")
            continue
        df = load_numeric(p, max_rows=MAX_ROWS_PER_FILE, fault_start_time=FAULT_START_TIME)
        if df.empty:
            print(f"[GLOBAL SCALER] WARNING: {cls} loaded but empty; skipping.")
            continue

        if feature_names is None:
            feature_names = df.columns.tolist()
        else:
            df = df.reindex(columns=feature_names)

        X_all_parts.append(df.to_numpy())

    if not X_all_parts or feature_names is None:
        raise RuntimeError("[GLOBAL SCALER] No data loaded from OM files.")

    X_all = np.concatenate(X_all_parts, axis=0)
    X_all = np.nan_to_num(X_all, nan=0.0)

    scaler = StandardScaler().fit(X_all)

    scaler_obj = {"mean": scaler.mean_.tolist(), "scale": scaler.scale_.tolist()}
    os.makedirs(ARTIFACTS_ROOT, exist_ok=True)
    with open(os.path.join(ARTIFACTS_ROOT, "scaler_mean_std.json"), "w") as f:
        json.dump(scaler_obj, f)
    with open(os.path.join(ARTIFACTS_ROOT, "feature_names.json"), "w") as f:
        json.dump(feature_names, f)

    print("[GLOBAL SCALER] Fitted scaler on all classes.")
    print(f"[GLOBAL SCALER] Features: {feature_names}")

    return scaler, feature_names


# ---------------------------------------------------------------------
# TRAIN ONE PAIRWISE AGENT: (cls_a vs cls_b)
# ---------------------------------------------------------------------
def train_pair_agent(
    client,
    scaler: StandardScaler,
    feature_names: List[str],
    cls_a: str,
    cls_b: str,
) -> Optional[Dict[str, Any]]:
    """
    Train LSTM for pair (cls_a vs cls_b).
    Label y=0 for cls_a, y=1 for cls_b.
    """
    exp_name = f"Agent {cls_a}_vs_{cls_b}"
    print("\n" + "=" * 80)
    print(f"TRAINING: {exp_name}")
    print("=" * 80)

    # ---- Load data for both classes ----
    X_parts, y_parts = [], []

    for cls, label in [(cls_a, 0), (cls_b, 1)]:
        p = find_file(cls)
        if p is None:
            print(f"[PAIR {cls_a} vs {cls_b}] File for {cls} not found; skipping pair.")
            return None
        df = load_numeric(p, max_rows=MAX_ROWS_PER_FILE, fault_start_time=FAULT_START_TIME)
        if df.empty:
            print(f"[PAIR {cls_a} vs {cls_b}] Empty data for {cls}; skipping pair.")
            return None

        # align columns
        df = df.reindex(columns=feature_names)
        X = df.to_numpy()
        X = np.nan_to_num(X, nan=0.0)

        # use global scaler
        X_scaled = scaler.transform(X)

        X_parts.append(X_scaled)
        y_parts.append(np.full(len(df), label, dtype=int))

    X_all = np.concatenate(X_parts, axis=0)
    y_all = np.concatenate(y_parts, axis=0)

    # ---- Create sequences ----
    X_seq, y_seq = create_sequences(X_all, y_all, time_steps=TIME_STEPS)
    print(
        f"[PAIR {cls_a} vs {cls_b}] X_seq={X_seq.shape}, "
        f"y_seq={y_seq.shape}, #0={(y_seq==0).sum()}, #1={(y_seq==1).sum()}"
    )
    if len(X_seq) < 2 * TIME_STEPS:
        print(f"[PAIR {cls_a} vs {cls_b}] Not enough sequences; skipping.")
        return None

    # ---- Split train/val/test ----
    X_tr, X_te, y_tr, y_te = train_test_split(
        X_seq, y_seq, test_size=TEST_SIZE, random_state=SEED, stratify=y_seq
    )
    X_tr, X_val, y_tr, y_val = train_test_split(
        X_tr, y_tr, test_size=VAL_SIZE, random_state=SEED, stratify=y_tr
    )

    # ---- Balance classes by oversampling ----
    idx = np.arange(len(y_tr))
    idx_a = idx[y_tr == 0]
    idx_b = idx[y_tr == 1]
    target = max(len(idx_a), len(idx_b))
    idx_a_bal = oversample_indices(idx_a, target)
    idx_b_bal = oversample_indices(idx_b, target)
    idx_final = np.concatenate([idx_a_bal, idx_b_bal])
    rng.shuffle(idx_final)

    X_tr_bal = X_tr[idx_final]
    y_tr_bal = y_tr[idx_final]

    print(
        f"[PAIR {cls_a} vs {cls_b}] After balancing: "
        f"#0={(y_tr_bal==0).sum()}, #1={(y_tr_bal==1).sum()}"
    )

    # ---- Build & train LSTM ----
    n_features = X_tr_bal.shape[2]
    model = Sequential(
        [
            layers.Input(shape=(TIME_STEPS, n_features)),
            LSTM(64, return_sequences=True),
            Dropout(0.2),
            LSTM(32, return_sequences=False),
            Dropout(0.2),
            layers.Dense(32, activation="relu"),
            layers.Dense(1, activation="sigmoid"),
        ],
        name=f"LSTM_{cls_a}_vs_{cls_b}",
    )

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=LR),
        loss="binary_crossentropy",
        metrics=["accuracy"],
    )

    print(f"[PAIR {cls_a} vs {cls_b}] Training LSTM...")
    model.fit(
        X_tr_bal,
        y_tr_bal,
        validation_data=(X_val, y_val),
        epochs=EPOCHS,
        batch_size=BATCH,
        verbose=1,
    )

    # ---- Evaluate ----
    probs_te = model.predict(X_te, verbose=0).flatten()
    y_pred_te = (probs_te > 0.5).astype(int)

    report = classification_report(
        y_te,
        y_pred_te,
        target_names=[cls_a, cls_b],
        output_dict=True,
        zero_division=0,
    )
    test_acc = report["accuracy"]
    print(f"[PAIR {cls_a} vs {cls_b}] Test accuracy: {test_acc:.4f}")

    # ---- Save model ----
    safe_name = exp_name.replace(" ", "_").replace("(", "").replace(")", "")
    save_dir = os.path.join(ARTIFACTS_ROOT, f"{safe_name}_artifacts")
    os.makedirs(save_dir, exist_ok=True)
    model.save(os.path.join(save_dir, "model_saved.keras"))

    publish_mqtt_report(client, exp_name, test_acc, report)

    return {"Pair": f"{cls_a}_vs_{cls_b}", "TestAcc": test_acc}


# ---------------------------------------------------------------------
# LIVE INFERENCE (pairwise voting)
# ---------------------------------------------------------------------
class LiveInference:
    def __init__(self):
        self.feature_names: List[str] = []
        self.scaler_mean: np.ndarray = None
        self.scaler_scale: np.ndarray = None

        # list of (cls_a, cls_b, model)
        self.pair_models: List[Tuple[str, str, tf.keras.Model]] = []

        self.history_buffer = deque(maxlen=TIME_STEPS)
        self.last_time: Optional[float] = None
        self.have_seen_any: bool = False

    def load_artifacts(self) -> bool:
        scaler_path = os.path.join(ARTIFACTS_ROOT, "scaler_mean_std.json")
        feat_path = os.path.join(ARTIFACTS_ROOT, "feature_names.json")

        if not (os.path.exists(scaler_path) and os.path.exists(feat_path)):
            print("[LIVE] Shared scaler/feature_names not found.")
            return False

        with open(scaler_path, "r") as f:
            scaler_obj = json.load(f)
        with open(feat_path, "r") as f:
            self.feature_names = json.load(f)

        self.scaler_mean = np.array(scaler_obj["mean"], dtype=float)
        self.scaler_scale = np.array(scaler_obj["scale"], dtype=float)
        self.scaler_scale[self.scaler_scale == 0] = 1.0

        # load all available pairwise models
        for cls_a, cls_b in PAIR_LIST:
            exp_name = f"Agent {cls_a}_vs_{cls_b}"
            folder = os.path.join(
                ARTIFACTS_ROOT,
                f"{exp_name.replace(' ', '_').replace('(', '').replace(')', '')}_artifacts",
            )
            model_path = os.path.join(folder, "model_saved.keras")
            if os.path.exists(model_path):
                print(f"[LIVE] Loading model for pair {cls_a} vs {cls_b}")
                model = tf.keras.models.load_model(model_path)
                self.pair_models.append((cls_a, cls_b, model))

        if not self.pair_models:
            print("[LIVE] No pairwise models loaded.")
            return False

        print(f"[LIVE] Using features ({len(self.feature_names)}): {self.feature_names}")
        print(f"[LIVE] Loaded {len(self.pair_models)} pairwise agents.")
        return True

    def reset_buffer(self):
        self.history_buffer.clear()
        self.last_time = None
        self.have_seen_any = False
        print("[LIVE] Buffer reset (new simulation).")

    def preprocess_single_vector(self, feat_dict: Dict[str, Any]) -> np.ndarray:
        x = np.zeros(len(self.feature_names), dtype=float)
        for i, name in enumerate(self.feature_names):
            try:
                x[i] = float(feat_dict.get(name, 0.0))
            except Exception:
                x[i] = 0.0
        return (x - self.scaler_mean) / self.scaler_scale

    def predict_fault(self) -> Dict[str, Any]:
        seq_data = np.array(self.history_buffer)  # (T,F)
        X_in = np.expand_dims(seq_data, axis=0)   # (1,T,F)

        votes = {cls: 0 for cls in CLASS_NAMES}
        score_sums = {cls: 0.0 for cls in CLASS_NAMES}

        for cls_a, cls_b, model in self.pair_models:
            prob_b = float(model.predict(X_in, verbose=0)[0, 0])
            prob_a = 1.0 - prob_b

            if prob_b >= 0.5:
                winner = cls_b
                prob = prob_b
            else:
                winner = cls_a
                prob = prob_a

            votes[winner] += 1
            score_sums[winner] += prob

        # pick class with maximum votes; break ties with summed score
        best_cls = None
        best_votes = -1
        best_score = -1.0
        for cls in CLASS_NAMES:
            v = votes[cls]
            s = score_sums[cls]
            if v > best_votes or (v == best_votes and s > best_score):
                best_cls = cls
                best_votes = v
                best_score = s

        winner = best_cls if best_cls is not None else "OM0"
        # normalize probability: average of winner's scores
        if best_votes > 0:
            winner_prob = best_score / best_votes
        else:
            winner_prob = 0.5

        # build agent_scores per class (normalized)
        agent_scores = {}
        total_votes = sum(votes.values()) + 1e-9
        for cls in CLASS_NAMES:
            agent_scores[cls] = votes[cls] / total_votes

        return {"agent_scores": agent_scores, "winner": winner, "winner_prob": winner_prob}

    def on_feature_message(self, client, userdata, msg):
        try:
            data = json.loads(msg.payload.decode())
            feat_dict = data.get("features", {})

            # ---- Time handling ----
            t = None
            if "Time" in feat_dict:
                try:
                    t = float(feat_dict["Time"])
                except Exception:
                    t = None

            # Detect new run once when Time ~ 0
            if t is not None:
                if (not self.have_seen_any) and t <= 1e-3:
                    self.reset_buffer()
                self.have_seen_any = True
                self.last_time = t

            # ---- Normal processing ----
            x_s = self.preprocess_single_vector(feat_dict)
            self.history_buffer.append(x_s)

            if len(self.history_buffer) < TIME_STEPS:
                print(f"[LIVE] Buffering... ({len(self.history_buffer)}/{TIME_STEPS})")
                return

            result = self.predict_fault()
            payload = {
                "timestamp": pd.Timestamp.now().isoformat(),
                "winner": result["winner"],
                "winner_prob": round(result["winner_prob"], 4),
                "agent_scores": {
                    k: round(v, 4) for k, v in result["agent_scores"].items()
                },
                "buffer_status": "READY",
            }
            out_msg = json.dumps(payload, ensure_ascii=False)
            print(f"[LIVE] Result: {result['winner']} ({result['winner_prob']:.2f})")
            client.publish(LIVE_RESULT_TOPIC, out_msg, qos=1)

        except Exception as e:
            print(f"[LIVE] Error: {e}")


def start_live_inference(client):
    live = LiveInference()
    if not live.load_artifacts():
        return

    client.message_callback_add(LIVE_FEATURE_TOPIC, live.on_feature_message)
    client.subscribe(LIVE_FEATURE_TOPIC)
    print(f"[LIVE] Listening on {LIVE_FEATURE_TOPIC}...")

    try:
        while True:
            time.sleep(1.0)
    except KeyboardInterrupt:
        print("[LIVE] Stopping live loop.")


# ---------------------------------------------------------------------
# MAIN
# ---------------------------------------------------------------------
def main():
    if not os.path.isdir(DATA_DIR):
        print(f"[ERROR] DATA_DIR '{DATA_DIR}' invalid.")
        return

    client = mqtt.Client(
        client_id=f"FaultDetection_Pairwise_{int(time.time())}",
        protocol=mqtt.MQTTv311,
        callback_api_version=CallbackAPIVersion.VERSION2,
    )
    client.on_connect = on_connect
    client.on_disconnect = on_disconnect
    client.on_message = on_message

    client.reconnect_delay_set(min_delay=1, max_delay=5)
client.connect(MQTT_BROKER, MQTT_PORT, keepalive=300)
client.loop_start()

try:
        # 1) Build global scaler (or load if already exists)
        if RUN_TRAINING:
            scaler, feature_names = build_global_scaler()
        else:
            # load existing scaler + features
            with open(os.path.join(ARTIFACTS_ROOT, "scaler_mean_std.json"), "r") as f:
                sc_obj = json.load(f)
            with open(os.path.join(ARTIFACTS_ROOT, "feature_names.json"), "r") as f:
                feature_names = json.load(f)
            scaler = StandardScaler()
            scaler.mean_ = np.array(sc_obj["mean"], dtype=float)
            scaler.scale_ = np.array(sc_obj["scale"], dtype=float)
            scaler.var_ = scaler.scale_ ** 2

        # 2) Train pairwise agents
        if RUN_TRAINING:
            results = []
            for cls_a, cls_b in PAIR_LIST:
                info = train_pair_agent(client, scaler, feature_names, cls_a, cls_b)
                if info:
                    results.append(info)
            if results:
                print("\n[TRAIN] Pairwise summary:")
                print(pd.DataFrame(results))

        # 3) Live inference loop
        start_live_inference(client)

    finally:
        client.loop_stop()
        client.disconnect()
        print("[MAIN] Done.")


if __name__ == "__main__":
    main()


IndentationError: unindent does not match any outer indentation level (<tokenize>, line 599)

In [1]:
import lstm_pairwise_main



In [3]:
lstm_pairwise_main.RUN_TRAINING = False
lstm_pairwise_main.main()


[MQTT] Connected rc=Success
[LIVE] Loading model for pair OM0 vs OM1
[LIVE] Loading model for pair OM0 vs OM2
[LIVE] Loading model for pair OM0 vs OM3
[LIVE] Loading model for pair OM0 vs OM4
[LIVE] Loading model for pair OM0 vs OM5
[LIVE] Loading model for pair OM0 vs OM6
[LIVE] Loading model for pair OM0 vs OM7
[LIVE] Loading model for pair OM0 vs OM8
[MQTT] Disconnected rc=Unspecified error
[LIVE] Loading model for pair OM1 vs OM2
[LIVE] Loading model for pair OM1 vs OM3
[LIVE] Loading model for pair OM1 vs OM4
[LIVE] Loading model for pair OM1 vs OM5
[LIVE] Loading model for pair OM1 vs OM6
[LIVE] Loading model for pair OM1 vs OM7
[LIVE] Loading model for pair OM1 vs OM8
[LIVE] Loading model for pair OM2 vs OM3
[MQTT] Connected rc=Success
[LIVE] Loading model for pair OM2 vs OM4
[LIVE] Loading model for pair OM2 vs OM5
[LIVE] Loading model for pair OM2 vs OM6
[LIVE] Loading model for pair OM2 vs OM7
[LIVE] Loading model for pair OM2 vs OM8
[MQTT] Disconnected rc=Unspecified error
[

In [None]:

# ======================= LSTM PAIRWISE AGENTS + MQTT LIVE =======================
import os
import json
import time
from collections import deque
from typing import List, Dict, Any, Optional, Tuple

import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report

import tensorflow as tf
from tensorflow.keras import Sequential, layers
from tensorflow.keras.layers import LSTM, Dropout

import paho.mqtt.client as mqtt
from paho.mqtt.client import CallbackAPIVersion

# ---------------------------------------------------------------------
# PATHS  (CHANGE THIS TO YOUR FOLDER IF NEEDED)
# ---------------------------------------------------------------------
DATA_DIR = r"C:\Users\Muckbul\Desktop\agenticAI"  # <-- EDIT IF NEEDED
ARTIFACTS_ROOT = DATA_DIR  # where models + scaler + feature_names go

# ---------------------------------------------------------------------
# MQTT CONFIG
# ---------------------------------------------------------------------
MQTT_BROKER = "localhost"
MQTT_PORT = 1883

TOPIC_BASE = "fault_detection/report/pairwise"
LIVE_FEATURE_TOPIC = "fault_detection/live/features"
LIVE_RESULT_TOPIC = "fault_detection/live/result"

# ---------------------------------------------------------------------
# TRAINING / MODEL CONFIG
# ---------------------------------------------------------------------
TIME_STEPS = 10
SEED, TEST_SIZE, VAL_SIZE = 42, 0.2, 0.2
BATCH, EPOCHS, LR = 128, 20, 1e-4

RUN_TRAINING = True   # set False when models already trained

# use only first N rows from each OM file for training
MAX_ROWS_PER_FILE = 500

# if you want only data after a certain time, set a float here
FAULT_START_TIME: Optional[float] = None  # e.g. 500.0, else None

rng = np.random.RandomState(SEED)
tf.random.set_seed(SEED)

# CLASS NAMES (OM0..OM8)
CLASS_NAMES = [f"OM{i}" for i in range(0, 9)]  # 9 classes total
# All unordered pairs (i < j)
PAIR_LIST: List[Tuple[str, str]] = [
    (CLASS_NAMES[i], CLASS_NAMES[j])
    for i in range(len(CLASS_NAMES))
    for j in range(i + 1, len(CLASS_NAMES))
]


# ---------------------------------------------------------------------
# BASIC UTILITIES
# ---------------------------------------------------------------------
def load_numeric(
    path: str,
    max_rows: Optional[int] = None,
    fault_start_time: Optional[float] = None,
) -> pd.DataFrame:
    """
    Load numeric columns from Excel/CSV.
    If 'Time' column exists and fault_start_time is not None, keep only
    rows with Time >= fault_start_time. Then keep only numeric columns.
    Optionally limit to 'max_rows' rows (after filtering).
    """
    ext = os.path.splitext(path)[1].lower()
    if ext in (".xlsx", ".xls"):
        df = pd.read_excel(path)
    elif ext == ".csv":
        df = pd.read_csv(path)
    else:
        raise ValueError(f"Unsupported file type: {ext}")

    # Force-convert Time to numeric (if present)
    if "Time" in df.columns:
        df["Time"] = pd.to_numeric(df["Time"], errors="coerce")
        df = df.dropna(subset=["Time"])
        if fault_start_time is not None:
            df = df[df["Time"] >= fault_start_time].copy()

    # Numeric columns only
    num = df.select_dtypes(include=[np.number]).copy()
    num = num.replace([np.inf, -np.inf], np.nan).ffill().bfill()

    if max_rows is not None and len(num) > max_rows:
        num = num.iloc[:max_rows].copy()

    return num


def find_file(base: str) -> Optional[str]:
    p_xlsx = os.path.join(DATA_DIR, f"{base}.xlsx")
    p_csv = os.path.join(DATA_DIR, f"{base}.csv")
    if os.path.exists(p_xlsx):
        return p_xlsx
    if os.path.exists(p_csv):
        return p_csv
    return None


def create_sequences(X: np.ndarray, y: np.ndarray, time_steps: int = 10, step: int = 1):
    Xs, ys = [], []
    for i in range(0, len(X) - time_steps, step):
        Xs.append(X[i: i + time_steps])
        ys.append(y[i + time_steps - 1])
    return np.array(Xs), np.array(ys)


def oversample_indices(idx_array: np.ndarray, n_target: int) -> np.ndarray:
    if len(idx_array) == 0:
        return idx_array
    if n_target <= len(idx_array):
        return rng.choice(idx_array, size=n_target, replace=False)
    extra = rng.choice(idx_array, size=n_target - len(idx_array), replace=True)
    return np.concatenate([idx_array, extra], axis=0)


def make_topic(exp_name: str) -> str:
    safe = exp_name.replace(" ", "_").replace("(", "").replace(")", "")
    return f"{TOPIC_BASE}/{safe}"


# ---------------------------------------------------------------------
# MQTT CALLBACKS
# ---------------------------------------------------------------------
def on_connect(client, userdata, flags, reason_code, properties=None):
    print(f"[MQTT] Connected rc={reason_code}")


def on_disconnect(client, userdata, flags, reason_code, properties=None):
    print(f"[MQTT] Disconnected rc={reason_code}")
    # try to interpret reason_code as int (for paho v2)
    try:
        rc_val = int(reason_code)
    except Exception:
        rc_val = getattr(reason_code, "value", 0)

    # non-zero means unexpected disconnect -> try to reconnect
    if rc_val != 0:
        print("[MQTT] Trying to reconnect...")
        try:
            client.reconnect()
        except Exception as e:
            print(f"[MQTT] Reconnect failed: {e}")


def on_message(client, userdata, msg):
    # Not used here; kept for completeness
    pass


def publish_mqtt_report(client, exp_name: str, accuracy: float, report_dict: dict):
    topic = make_topic(exp_name)
    payload = {
        "agent": exp_name,
        "model_type": "LSTM_pairwise",
        "overall_test_accuracy": round(float(accuracy), 4),
        "weighted_f1": round(float(report_dict["weighted avg"]["f1-score"]), 4),
        "timestamp": pd.Timestamp.now().isoformat(),
    }
    msg = json.dumps(payload, ensure_ascii=False)
    print(f"[MQTT] Publish summary -> {topic}")
    if client:
        try:
            client.publish(topic, msg, qos=1)
        except Exception as e:
            print(f"[MQTT] Publish error: {e}")


# ---------------------------------------------------------------------
# GLOBAL SCALER (shared by all pairwise agents)
# ---------------------------------------------------------------------
def build_global_scaler() -> Tuple[StandardScaler, List[str]]:
    """
    Load all OM0..OM8 data (limited to MAX_ROWS_PER_FILE) and fit one
    StandardScaler on the concatenated dataset. Save mean/scale and
    feature_names to ARTIFACTS_ROOT.
    """
    X_all_parts = []
    feature_names: Optional[List[str]] = None

    for cls in CLASS_NAMES:
        p = find_file(cls)
        if p is None:
            print(f"[GLOBAL SCALER] WARNING: file for {cls} not found; skipping.")
            continue
        df = load_numeric(p, max_rows=MAX_ROWS_PER_FILE, fault_start_time=FAULT_START_TIME)
        if df.empty:
            print(f"[GLOBAL SCALER] WARNING: {cls} loaded but empty; skipping.")
            continue

        if feature_names is None:
            feature_names = df.columns.tolist()
        else:
            df = df.reindex(columns=feature_names)

        X_all_parts.append(df.to_numpy())

    if not X_all_parts or feature_names is None:
        raise RuntimeError("[GLOBAL SCALER] No data loaded from OM files.")

    X_all = np.concatenate(X_all_parts, axis=0)
    X_all = np.nan_to_num(X_all, nan=0.0)

    scaler = StandardScaler().fit(X_all)

    scaler_obj = {"mean": scaler.mean_.tolist(), "scale": scaler.scale_.tolist()}
    os.makedirs(ARTIFACTS_ROOT, exist_ok=True)
    with open(os.path.join(ARTIFACTS_ROOT, "scaler_mean_std.json"), "w") as f:
        json.dump(scaler_obj, f)
    with open(os.path.join(ARTIFACTS_ROOT, "feature_names.json"), "w") as f:
        json.dump(feature_names, f)

    print("[GLOBAL SCALER] Fitted scaler on all classes.")
    print(f"[GLOBAL SCALER] Features: {feature_names}")

    return scaler, feature_names


# ---------------------------------------------------------------------
# TRAIN ONE PAIRWISE AGENT: (cls_a vs cls_b)
# ---------------------------------------------------------------------
def train_pair_agent(
    client,
    scaler: StandardScaler,
    feature_names: List[str],
    cls_a: str,
    cls_b: str,
) -> Optional[Dict[str, Any]]:
    """
    Train LSTM for pair (cls_a vs cls_b).
    Label y=0 for cls_a, y=1 for cls_b.
    """
    exp_name = f"Agent {cls_a}_vs_{cls_b}"
    print("\n" + "=" * 80)
    print(f"TRAINING: {exp_name}")
    print("=" * 80)

    # ---- Load data for both classes ----
    X_parts, y_parts = [], []

    for cls, label in [(cls_a, 0), (cls_b, 1)]:
        p = find_file(cls)
        if p is None:
            print(f"[PAIR {cls_a} vs {cls_b}] File for {cls} not found; skipping pair.")
            return None
        df = load_numeric(p, max_rows=MAX_ROWS_PER_FILE, fault_start_time=FAULT_START_TIME)
        if df.empty:
            print(f"[PAIR {cls_a} vs {cls_b}] Empty data for {cls}; skipping pair.")
            return None

        # align columns
        df = df.reindex(columns=feature_names)
        X = df.to_numpy()
        X = np.nan_to_num(X, nan=0.0)

        # use global scaler
        X_scaled = scaler.transform(X)

        X_parts.append(X_scaled)
        y_parts.append(np.full(len(df), label, dtype=int))

    X_all = np.concatenate(X_parts, axis=0)
    y_all = np.concatenate(y_parts, axis=0)

    # ---- Create sequences ----
    X_seq, y_seq = create_sequences(X_all, y_all, time_steps=TIME_STEPS)
    print(
        f"[PAIR {cls_a} vs {cls_b}] X_seq={X_seq.shape}, "
        f"y_seq={y_seq.shape}, #0={(y_seq == 0).sum()}, #1={(y_seq == 1).sum()}"
    )
    if len(X_seq) < 2 * TIME_STEPS:
        print(f"[PAIR {cls_a} vs {cls_b}] Not enough sequences; skipping.")
        return None

    # ---- Split train/val/test ----
    X_tr, X_te, y_tr, y_te = train_test_split(
        X_seq, y_seq, test_size=TEST_SIZE, random_state=SEED, stratify=y_seq
    )
    X_tr, X_val, y_tr, y_val = train_test_split(
        X_tr, y_tr, test_size=VAL_SIZE, random_state=SEED, stratify=y_tr
    )

    # ---- Balance classes by oversampling ----
    idx = np.arange(len(y_tr))
    idx_a = idx[y_tr == 0]
    idx_b = idx[y_tr == 1]
    target = max(len(idx_a), len(idx_b))
    idx_a_bal = oversample_indices(idx_a, target)
    idx_b_bal = oversample_indices(idx_b, target)
    idx_final = np.concatenate([idx_a_bal, idx_b_bal])
    rng.shuffle(idx_final)

    X_tr_bal = X_tr[idx_final]
    y_tr_bal = y_tr[idx_final]

    print(
        f"[PAIR {cls_a} vs {cls_b}] After balancing: "
        f"#0={(y_tr_bal == 0).sum()}, #1={(y_tr_bal == 1).sum()}"
    )

    # ---- Build & train LSTM ----
    n_features = X_tr_bal.shape[2]
    model = Sequential(
        [
            layers.Input(shape=(TIME_STEPS, n_features)),
            LSTM(64, return_sequences=True),
            Dropout(0.2),
            LSTM(32, return_sequences=False),
            Dropout(0.2),
            layers.Dense(32, activation="relu"),
            layers.Dense(1, activation="sigmoid"),
        ],
        name=f"LSTM_{cls_a}_vs_{cls_b}",
    )

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=LR),
        loss="binary_crossentropy",
        metrics=["accuracy"],
    )

    print(f"[PAIR {cls_a} vs {cls_b}] Training LSTM...")
    model.fit(
        X_tr_bal,
        y_tr_bal,
        validation_data=(X_val, y_val),
        epochs=EPOCHS,
        batch_size=BATCH,
        verbose=1,
    )

    # ---- Evaluate ----
    probs_te = model.predict(X_te, verbose=0).flatten()
    y_pred_te = (probs_te > 0.5).astype(int)

    report = classification_report(
        y_te,
        y_pred_te,
        target_names=[cls_a, cls_b],
        output_dict=True,
        zero_division=0,
    )
    test_acc = report["accuracy"]
    print(f"[PAIR {cls_a} vs {cls_b}] Test accuracy: {test_acc:.4f}")

    # ---- Save model ----
    safe_name = exp_name.replace(" ", "_").replace("(", "").replace(")", "")
    save_dir = os.path.join(ARTIFACTS_ROOT, f"{safe_name}_artifacts")
    os.makedirs(save_dir, exist_ok=True)
    model.save(os.path.join(save_dir, "model_saved.keras"))

    publish_mqtt_report(client, exp_name, test_acc, report)

    return {"Pair": f"{cls_a}_vs_{cls_b}", "TestAcc": test_acc}


# ---------------------------------------------------------------------
# LIVE INFERENCE (pairwise voting)
# ---------------------------------------------------------------------
class LiveInference:
    def __init__(self):
        self.feature_names: List[str] = []
        self.scaler_mean: np.ndarray = None
        self.scaler_scale: np.ndarray = None

        # list of (cls_a, cls_b, model)
        self.pair_models: List[Tuple[str, str, tf.keras.Model]] = []

        self.history_buffer = deque(maxlen=TIME_STEPS)
        self.last_time: Optional[float] = None
        self.have_seen_any: bool = False

    def load_artifacts(self) -> bool:
        scaler_path = os.path.join(ARTIFACTS_ROOT, "scaler_mean_std.json")
        feat_path = os.path.join(ARTIFACTS_ROOT, "feature_names.json")

        if not (os.path.exists(scaler_path) and os.path.exists(feat_path)):
            print("[LIVE] Shared scaler/feature_names not found.")
            return False

        with open(scaler_path, "r") as f:
            scaler_obj = json.load(f)
        with open(feat_path, "r") as f:
            self.feature_names = json.load(f)

        self.scaler_mean = np.array(scaler_obj["mean"], dtype=float)
        self.scaler_scale = np.array(scaler_obj["scale"], dtype=float)
        self.scaler_scale[self.scaler_scale == 0] = 1.0

        # load all available pairwise models
        for cls_a, cls_b in PAIR_LIST:
            exp_name = f"Agent {cls_a}_vs_{cls_b}"
            folder = os.path.join(
                ARTIFACTS_ROOT,
                f"{exp_name.replace(' ', '_').replace('(', '').replace(')', '')}_artifacts",
            )
            model_path = os.path.join(folder, "model_saved.keras")
            if os.path.exists(model_path):
                print(f"[LIVE] Loading model for pair {cls_a} vs {cls_b}")
                model = tf.keras.models.load_model(model_path)
                self.pair_models.append((cls_a, cls_b, model))

        if not self.pair_models:
            print("[LIVE] No pairwise models loaded.")
            return False

        print(f"[LIVE] Using features ({len(self.feature_names)}): {self.feature_names}")
        print(f"[LIVE] Loaded {len(self.pair_models)} pairwise agents.")
        return True

    def reset_buffer(self):
        self.history_buffer.clear()
        self.last_time = None
        self.have_seen_any = False
        print("[LIVE] Buffer reset (new simulation).")

    def preprocess_single_vector(self, feat_dict: Dict[str, Any]) -> np.ndarray:
        x = np.zeros(len(self.feature_names), dtype=float)
        for i, name in enumerate(self.feature_names):
            try:
                x[i] = float(feat_dict.get(name, 0.0))
            except Exception:
                x[i] = 0.0
        return (x - self.scaler_mean) / self.scaler_scale

    def predict_fault(self) -> Dict[str, Any]:
        seq_data = np.array(self.history_buffer)  # (T,F)
        X_in = np.expand_dims(seq_data, axis=0)   # (1,T,F)

        votes = {cls: 0 for cls in CLASS_NAMES}
        score_sums = {cls: 0.0 for cls in CLASS_NAMES}

        for cls_a, cls_b, model in self.pair_models:
            prob_b = float(model.predict(X_in, verbose=0)[0, 0])
            prob_a = 1.0 - prob_b

            if prob_b >= 0.5:
                winner = cls_b
                prob = prob_b
            else:
                winner = cls_a
                prob = prob_a

            votes[winner] += 1
            score_sums[winner] += prob

        # pick class with maximum votes; break ties with summed score
        best_cls = None
        best_votes = -1
        best_score = -1.0
        for cls in CLASS_NAMES:
            v = votes[cls]
            s = score_sums[cls]
            if v > best_votes or (v == best_votes and s > best_score):
                best_cls = cls
                best_votes = v
                best_score = s

        winner = best_cls if best_cls is not None else "OM0"
        # normalize probability: average of winner's scores
        if best_votes > 0:
            winner_prob = best_score / best_votes
        else:
            winner_prob = 0.5

        # build agent_scores per class (normalized)
        agent_scores = {}
        total_votes = sum(votes.values()) + 1e-9
        for cls in CLASS_NAMES:
            agent_scores[cls] = votes[cls] / total_votes

        return {"agent_scores": agent_scores, "winner": winner, "winner_prob": winner_prob}

    def on_feature_message(self, client, userdata, msg):
        try:
            data = json.loads(msg.payload.decode())
            feat_dict = data.get("features", {})

            # ---- Time handling ----
            t = None
            if "Time" in feat_dict:
                try:
                    t = float(feat_dict["Time"])
                except Exception:
                    t = None

            # Detect new run once when Time ~ 0
            if t is not None:
                if (not self.have_seen_any) and t <= 1e-3:
                    self.reset_buffer()
                self.have_seen_any = True
                self.last_time = t

            # ---- Normal processing ----
            x_s = self.preprocess_single_vector(feat_dict)
            self.history_buffer.append(x_s)

            if len(self.history_buffer) < TIME_STEPS:
                print(f"[LIVE] Buffering... ({len(self.history_buffer)}/{TIME_STEPS})")
                return

            result = self.predict_fault()
            payload = {
                "timestamp": pd.Timestamp.now().isoformat(),
                "winner": result["winner"],
                "winner_prob": round(result["winner_prob"], 4),
                "agent_scores": {
                    k: round(v, 4) for k, v in result["agent_scores"].items()
                },
                "buffer_status": "READY",
            }
            out_msg = json.dumps(payload, ensure_ascii=False)
            print(f"[LIVE] Result: {result['winner']} ({result['winner_prob']:.2f})")
            client.publish(LIVE_RESULT_TOPIC, out_msg, qos=1)

        except Exception as e:
            print(f"[LIVE] Error: {e}")


def start_live_inference(client):
    """
    Attach callbacks and block in a simple loop.
    Call this after training / loading models.
    """
    live = LiveInference()
    if not live.load_artifacts():
        return

    client.message_callback_add(LIVE_FEATURE_TOPIC, live.on_feature_message)
    client.subscribe(LIVE_FEATURE_TOPIC)
    print(f"[LIVE] Listening on {LIVE_FEATURE_TOPIC}...")

    try:
        while True:
            time.sleep(1.0)
    except KeyboardInterrupt:
        print("[LIVE] Stopping live loop.")


# ---------------------------------------------------------------------
# MAIN
# ---------------------------------------------------------------------
def main():
    if not os.path.isdir(DATA_DIR):
        print(f"[ERROR] DATA_DIR '{DATA_DIR}' invalid.")
        return

    # ------- MQTT CLIENT (stable) -------
    unique_id = f"FaultDetection_Pairwise_{int(time.time())}"

    client = mqtt.Client(
        client_id=unique_id,
        protocol=mqtt.MQTTv311,
        callback_api_version=CallbackAPIVersion.VERSION2,
    )
    client.on_connect = on_connect
    client.on_disconnect = on_disconnect
    client.on_message = on_message

    # reconnection behaviour
    client.reconnect_delay_set(min_delay=1, max_delay=5)

    # connect and start background loop
    client.connect(MQTT_BROKER, MQTT_PORT, keepalive=300)
    client.loop_start()

    try:
        # 1) Build global scaler (or load if already exists)
        if RUN_TRAINING:
            scaler, feature_names = build_global_scaler()
        else:
            with open(os.path.join(ARTIFACTS_ROOT, "scaler_mean_std.json"), "r") as f:
                sc_obj = json.load(f)
            with open(os.path.join(ARTIFACTS_ROOT, "feature_names.json"), "r") as f:
                feature_names = json.load(f)
            scaler = StandardScaler()
            scaler.mean_ = np.array(sc_obj["mean"], dtype=float)
            scaler.scale_ = np.array(sc_obj["scale"], dtype=float)
            scaler.var_ = scaler.scale_ ** 2

        # 2) Train pairwise agents
        if RUN_TRAINING:
            results = []
            for cls_a, cls_b in PAIR_LIST:
                info = train_pair_agent(client, scaler, feature_names, cls_a, cls_b)
                if info:
                    results.append(info)
            if results:
                print("\n[TRAIN] Pairwise summary:")
                print(pd.DataFrame(results))

        # 3) Live inference loop
        start_live_inference(client)

    finally:
        client.loop_stop()
        client.disconnect()
        print("[MAIN] Done.")


if __name__ == "__main__":
    main()


[MQTT] Connected rc=Success
[GLOBAL SCALER] Fitted scaler on all classes.
[GLOBAL SCALER] Features: ['Time', 'Data_ 1', 'Data_ 2', 'Data_ 3', 'Data_ 4', 'Data_ 5', 'Data_ 6', 'Data_ 7', 'Data_ 8', 'Data_ 9', 'Data_10', 'Data_11', 'Data_12', 'Data_13', 'Data_14', 'Data_15', 'Data_16', 'Data_17']

TRAINING: Agent OM0_vs_OM1
[PAIR OM0 vs OM1] X_seq=(990, 10, 18), y_seq=(990,), #0=491, #1=499
[PAIR OM0 vs OM1] After balancing: #0=319, #1=319
[PAIR OM0 vs OM1] Training LSTM...
Epoch 1/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 123ms/step - accuracy: 0.3433 - loss: 0.7126 - val_accuracy: 0.2830 - val_loss: 0.7093
Epoch 2/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.4107 - loss: 0.7049 - val_accuracy: 0.5031 - val_loss: 0.6996
Epoch 3/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.4953 - loss: 0.6954 - val_accuracy: 0.5031 - val_loss: 0.6909
Epoch 4/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━

[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.3495 - loss: 0.7357 - val_accuracy: 0.3774 - val_loss: 0.7222
Epoch 3/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.4216 - loss: 0.7140 - val_accuracy: 0.4906 - val_loss: 0.7036
Epoch 4/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.4812 - loss: 0.7004 - val_accuracy: 0.6667 - val_loss: 0.6866
Epoch 5/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.6019 - loss: 0.6874 - val_accuracy: 0.8931 - val_loss: 0.6724
Epoch 6/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.7257 - loss: 0.6734 - val_accuracy: 0.9308 - val_loss: 0.6598
Epoch 7/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.7743 - loss: 0.6645 - val_accuracy: 0.9748 - val_loss: 0.6481
Epoch 8/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m

[PAIR OM0 vs OM5] X_seq=(990, 10, 18), y_seq=(990,), #0=491, #1=499
[PAIR OM0 vs OM5] After balancing: #0=319, #1=319
[PAIR OM0 vs OM5] Training LSTM...
Epoch 1/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 120ms/step - accuracy: 0.5768 - loss: 0.6608 - val_accuracy: 0.4906 - val_loss: 0.6467
Epoch 2/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.6489 - loss: 0.6480 - val_accuracy: 0.6415 - val_loss: 0.6302
Epoch 3/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 0.7351 - loss: 0.6325 - val_accuracy: 0.9182 - val_loss: 0.6141
Epoch 4/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.7994 - loss: 0.6186 - val_accuracy: 0.9811 - val_loss: 0.5985
Epoch 5/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.8809 - loss: 0.6027 - val_accuracy: 0.9811 - val_loss: 0.5839
Epoch 6/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0

[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.9216 - loss: 0.6182 - val_accuracy: 0.9937 - val_loss: 0.6019
Epoch 6/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.9718 - loss: 0.6037 - val_accuracy: 0.9937 - val_loss: 0.5868
Epoch 7/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.9781 - loss: 0.5921 - val_accuracy: 0.9937 - val_loss: 0.5715
Epoch 8/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 0.9875 - loss: 0.5768 - val_accuracy: 0.9937 - val_loss: 0.5562
Epoch 9/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.9843 - loss: 0.5627 - val_accuracy: 0.9937 - val_loss: 0.5408
Epoch 10/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.9922 - loss: 0.5469 - val_accuracy: 0.9937 - val_loss: 0.5252
Epoch 11/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[3

[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.8856 - loss: 0.6221 - val_accuracy: 0.9497 - val_loss: 0.5980
Epoch 11/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.8652 - loss: 0.6089 - val_accuracy: 0.9371 - val_loss: 0.5842
Epoch 12/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.8934 - loss: 0.5974 - val_accuracy: 0.9371 - val_loss: 0.5698
Epoch 13/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.9107 - loss: 0.5830 - val_accuracy: 0.9308 - val_loss: 0.5542
Epoch 14/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.8966 - loss: 0.5703 - val_accuracy: 0.9308 - val_loss: 0.5381
Epoch 15/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.9060 - loss: 0.5588 - val_accuracy: 0.9308 - val_loss: 0.5218
Epoch 16/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0

[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.8135 - loss: 0.5891 - val_accuracy: 0.8176 - val_loss: 0.5678
Epoch 16/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - accuracy: 0.8150 - loss: 0.5755 - val_accuracy: 0.8239 - val_loss: 0.5548
Epoch 17/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 0.8307 - loss: 0.5636 - val_accuracy: 0.8365 - val_loss: 0.5408
Epoch 18/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.8370 - loss: 0.5516 - val_accuracy: 0.8365 - val_loss: 0.5257
Epoch 19/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 33ms/step - accuracy: 0.8370 - loss: 0.5346 - val_accuracy: 0.8428 - val_loss: 0.5097
Epoch 20/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.8417 - loss: 0.5184 - val_accuracy: 0.8553 - val_loss: 0.4927
[PAIR OM1 vs OM4] Test accuracy: 0.8687
[MQTT] Publi

[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 41ms/step - accuracy: 0.9781 - loss: 0.3434 - val_accuracy: 0.9874 - val_loss: 0.3080
[PAIR OM1 vs OM6] Test accuracy: 0.9899
[MQTT] Publish summary -> fault_detection/report/pairwise/Agent_OM1_vs_OM6

TRAINING: Agent OM1_vs_OM7
[PAIR OM1 vs OM7] X_seq=(990, 10, 18), y_seq=(990,), #0=491, #1=499
[PAIR OM1 vs OM7] After balancing: #0=319, #1=319
[PAIR OM1 vs OM7] Training LSTM...
Epoch 1/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 122ms/step - accuracy: 0.3495 - loss: 0.7204 - val_accuracy: 0.4340 - val_loss: 0.7119
Epoch 2/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.4389 - loss: 0.7023 - val_accuracy: 0.5597 - val_loss: 0.6935
Epoch 3/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.5611 - loss: 0.6855 - val_accuracy: 0.5849 - val_loss: 0.6757
Epoch 4/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/

Epoch 3/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.9734 - loss: 0.5883 - val_accuracy: 0.9937 - val_loss: 0.5686
Epoch 4/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 33ms/step - accuracy: 0.9796 - loss: 0.5687 - val_accuracy: 0.9811 - val_loss: 0.5465
Epoch 5/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.9718 - loss: 0.5491 - val_accuracy: 0.9811 - val_loss: 0.5249
Epoch 6/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.9639 - loss: 0.5301 - val_accuracy: 0.9811 - val_loss: 0.5036
Epoch 7/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.9592 - loss: 0.5080 - val_accuracy: 0.9811 - val_loss: 0.4825
Epoch 8/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.9624 - loss: 0.4859 - val_accuracy: 0.9811 - val_loss: 0.4616
Epoch 9/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━

[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.9718 - loss: 0.5258 - val_accuracy: 0.9811 - val_loss: 0.5011
Epoch 9/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.9655 - loss: 0.5091 - val_accuracy: 0.9811 - val_loss: 0.4822
Epoch 10/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.9765 - loss: 0.4936 - val_accuracy: 0.9811 - val_loss: 0.4635
Epoch 11/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.9765 - loss: 0.4735 - val_accuracy: 0.9811 - val_loss: 0.4453
Epoch 12/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.9781 - loss: 0.4595 - val_accuracy: 0.9811 - val_loss: 0.4273
Epoch 13/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.9749 - loss: 0.4394 - val_accuracy: 0.9811 - val_loss: 0.4094
Epoch 14/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m

[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.9702 - loss: 0.4699 - val_accuracy: 0.9748 - val_loss: 0.4388
Epoch 14/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.9687 - loss: 0.4516 - val_accuracy: 0.9748 - val_loss: 0.4203
Epoch 15/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 35ms/step - accuracy: 0.9749 - loss: 0.4326 - val_accuracy: 0.9811 - val_loss: 0.4018
Epoch 16/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.9718 - loss: 0.4152 - val_accuracy: 0.9874 - val_loss: 0.3835
Epoch 17/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.9765 - loss: 0.3963 - val_accuracy: 0.9874 - val_loss: 0.3655
Epoch 18/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.9796 - loss: 0.3799 - val_accuracy: 0.9937 - val_loss: 0.3477
Epoch 19/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0

[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 35ms/step - accuracy: 0.9671 - loss: 0.5097 - val_accuracy: 0.9937 - val_loss: 0.4771
Epoch 19/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.9796 - loss: 0.4898 - val_accuracy: 0.9937 - val_loss: 0.4568
Epoch 20/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.9843 - loss: 0.4660 - val_accuracy: 0.9937 - val_loss: 0.4355
[PAIR OM3 vs OM4] Test accuracy: 0.9949
[MQTT] Publish summary -> fault_detection/report/pairwise/Agent_OM3_vs_OM4

TRAINING: Agent OM3_vs_OM5
[PAIR OM3 vs OM5] X_seq=(990, 10, 18), y_seq=(990,), #0=491, #1=499
[PAIR OM3 vs OM5] After balancing: #0=319, #1=319
[PAIR OM3 vs OM5] Training LSTM...
Epoch 1/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 122ms/step - accuracy: 0.6191 - loss: 0.6099 - val_accuracy: 0.5912 - val_loss: 0.5908
Epoch 2/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28m

[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 120ms/step - accuracy: 0.6129 - loss: 0.6786 - val_accuracy: 0.7358 - val_loss: 0.6681
Epoch 2/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.7571 - loss: 0.6613 - val_accuracy: 0.9119 - val_loss: 0.6508
Epoch 3/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.8448 - loss: 0.6489 - val_accuracy: 0.9497 - val_loss: 0.6327
Epoch 4/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.8809 - loss: 0.6317 - val_accuracy: 0.9748 - val_loss: 0.6141
Epoch 5/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.9248 - loss: 0.6168 - val_accuracy: 0.9748 - val_loss: 0.5954
Epoch 6/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.9483 - loss: 0.5985 - val_accuracy: 0.9748 - val_loss: 0.5770
Epoch 7/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37

[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.9639 - loss: 0.5789 - val_accuracy: 0.9811 - val_loss: 0.5567
Epoch 7/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.9749 - loss: 0.5634 - val_accuracy: 0.9811 - val_loss: 0.5416
Epoch 8/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.9655 - loss: 0.5489 - val_accuracy: 0.9811 - val_loss: 0.5260
Epoch 9/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.9671 - loss: 0.5357 - val_accuracy: 0.9811 - val_loss: 0.5105
Epoch 10/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.9734 - loss: 0.5234 - val_accuracy: 0.9811 - val_loss: 0.4952
Epoch 11/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.9702 - loss: 0.5057 - val_accuracy: 0.9748 - val_loss: 0.4800
Epoch 12/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[

[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.6034 - loss: 0.6017 - val_accuracy: 0.6667 - val_loss: 0.5831
Epoch 12/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.6614 - loss: 0.5864 - val_accuracy: 0.7296 - val_loss: 0.5697
Epoch 13/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.7132 - loss: 0.5757 - val_accuracy: 0.7673 - val_loss: 0.5562
Epoch 14/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.7445 - loss: 0.5618 - val_accuracy: 0.7925 - val_loss: 0.5425
Epoch 15/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.7774 - loss: 0.5517 - val_accuracy: 0.8868 - val_loss: 0.5286
Epoch 16/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.8574 - loss: 0.5365 - val_accuracy: 0.9623 - val_loss: 0.5144
Epoch 17/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0

[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.5705 - loss: 0.6800 - val_accuracy: 0.5535 - val_loss: 0.6775
Epoch 17/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - accuracy: 0.6301 - loss: 0.6742 - val_accuracy: 0.5535 - val_loss: 0.6755
Epoch 18/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - accuracy: 0.5831 - loss: 0.6730 - val_accuracy: 0.5597 - val_loss: 0.6733
Epoch 19/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.5972 - loss: 0.6726 - val_accuracy: 0.5597 - val_loss: 0.6707
Epoch 20/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.6505 - loss: 0.6671 - val_accuracy: 0.5597 - val_loss: 0.6680
[PAIR OM5 vs OM6] Test accuracy: 0.5556
[MQTT] Publish summary -> fault_detection/report/pairwise/Agent_OM5_vs_OM6

TRAINING: Agent OM5_vs_OM7
[PAIR OM5 vs OM7] X_seq=(990, 10, 18), y_seq=(990,), #0=491, #1=499
[

[PAIR OM6 vs OM7] X_seq=(990, 10, 18), y_seq=(990,), #0=491, #1=499
[PAIR OM6 vs OM7] After balancing: #0=319, #1=319
[PAIR OM6 vs OM7] Training LSTM...
Epoch 1/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 123ms/step - accuracy: 0.4969 - loss: 0.6846 - val_accuracy: 0.4969 - val_loss: 0.6803
Epoch 2/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 0.5439 - loss: 0.6781 - val_accuracy: 0.5849 - val_loss: 0.6740
Epoch 3/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.5862 - loss: 0.6714 - val_accuracy: 0.6478 - val_loss: 0.6676
Epoch 4/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step - accuracy: 0.6238 - loss: 0.6659 - val_accuracy: 0.6478 - val_loss: 0.6607
Epoch 5/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/step - accuracy: 0.6066 - loss: 0.6617 - val_accuracy: 0.6541 - val_loss: 0.6534
Epoch 6/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0

[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step - accuracy: 0.9091 - loss: 0.6273 - val_accuracy: 0.9245 - val_loss: 0.6151
Epoch 6/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step - accuracy: 0.9295 - loss: 0.6157 - val_accuracy: 0.9371 - val_loss: 0.6017
Epoch 7/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step - accuracy: 0.9389 - loss: 0.6010 - val_accuracy: 0.9623 - val_loss: 0.5884
Epoch 8/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 39ms/step - accuracy: 0.9608 - loss: 0.5893 - val_accuracy: 0.9811 - val_loss: 0.5747
Epoch 9/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 39ms/step - accuracy: 0.9592 - loss: 0.5796 - val_accuracy: 0.9937 - val_loss: 0.5608
Epoch 10/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 36ms/step - accuracy: 0.9749 - loss: 0.5626 - val_accuracy: 0.9937 - val_loss: 0.5463
Epoch 11/20
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[3

In [3]:
import lstm_pairwise_main


In [1]:
lstm_pairwise_main.RUN_TRAINING = True
lstm_pairwise_main.main()   # cell will run until you stop it


NameError: name 'lstm_pairwise_main' is not defined