In [None]:
###########################   MULTI TRAINING ON CLEAN LOG   #########################################

########################   TRAINING ON CLEAN LOG   ##################################



import math, random, time, re, json
from collections import defaultdict
from typing import List, Tuple, Dict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import RobustScaler
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from statsmodels import robust

import os, pickle, pathlib, re

import matplotlib.pyplot as plt
import seaborn as sns

######### Configuration Constants
# Each note below describes how the visible code in this file uses the value.
SEED = 42             # seed shared by `random`, NumPy and PyTorch (applied below)
LATENT_DIM = 32       # default bottleneck width of the AE (`lat` argument)
DROPOUT = 0.15        # only referenced by the commented-out AE variants in the visible code
LR = 1e-5             # learning rate handed to AdamW
BATCH = 64            # DataLoader batch size for training and validation
EPOCHS_MAX = 200      # upper bound on epochs; also T_max of the cosine LR schedule
PATIENCE = 12         # epochs without validation improvement before early stopping
PCA_VAR = 0.97        # cumulative explained-variance target used to choose the PCA size
PCA_MAX_FRAC = 0.5    # caps the PCA size at this fraction of the number of training rows
JITTER_STD = 0.02     # std of the Gaussian noise added to the raw training features

# Seed every random number generator the pipeline draws from.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

########### ----------------- helpers ----------------- ####
def quat_angle_deg(q1, q2):
    """Return ``2 * acos(|q1 . q2|)`` in degrees.

    The absolute dot product is clamped to [0, 1] before ``acos`` so rounding
    noise cannot push it out of the function's domain. A non-finite result
    (e.g. NaN inputs) is reported as 0.0.

    NOTE(review): the formula presumably expects unit quaternions; nothing
    here normalises the inputs.
    """
    cos_half = np.clip(abs(np.dot(q1, q2)), 0.0, 1.0)
    angle = 2.0 * math.degrees(math.acos(float(cos_half)))
    if np.isfinite(angle):
        return angle
    return 0.0

def safe_mom(x):
    """Summarise samples as ``(mean, std, min, max, skew, excess_kurtosis)``.

    Degenerate inputs never raise:
      * empty input            -> six zeros
      * a single sample        -> std, skew and kurtosis are 0.0
      * (near-)constant input  -> skew and kurtosis are 0.0
    ``std`` is the sample standard deviation (ddof=1); skew and kurtosis are
    built from population central moments, and any non-finite value of either
    is replaced by 0.0.
    """
    vals = np.asarray(x, float)
    count = vals.size
    if count == 0:
        return (0.0,) * 6

    mean = float(np.mean(vals))
    lo, hi = float(vals.min()), float(vals.max())
    if count < 2:
        return (mean, 0.0, lo, hi, 0.0, 0.0)

    std = float(np.std(vals, ddof=1))
    if std <= 1e-12:
        return (mean, std, lo, hi, 0.0, 0.0)

    centred = vals - mean
    second = np.mean(centred ** 2)
    # The small epsilon keeps the ratios defined when the variance underflows.
    skew = np.mean(centred ** 3) / (second ** 1.5 + 1e-12)
    kurt = np.mean(centred ** 4) / (second ** 2 + 1e-12) - 3.0
    skew = float(skew) if np.isfinite(skew) else 0.0
    kurt = float(kurt) if np.isfinite(kurt) else 0.0
    return (mean, std, lo, hi, skew, kurt)

def add_comp(prefix, arr, feat, cols):
    """Append the six ``safe_mom`` statistics of every column of ``arr``.

    Values are appended to ``feat`` and their names, formatted as
    ``{prefix}_{column}_{stat}``, to ``cols``; both lists are mutated in place.
    An empty ``arr`` is treated as zero rows of three components, so the
    number of emitted features stays fixed (all zeros).
    """
    stat_names = ("mean", "std", "min", "max", "skew", "kurt")
    data = np.zeros((0, 3)) if arr.size == 0 else arr
    for col in range(data.shape[1]):
        moments = safe_mom(data[:, col])
        for name, value in zip(stat_names, moments):
            feat.append(value)
            cols.append(f"{prefix}_{col}_{name}")

def add_scal(prefix, x, feat, cols):
    """Append the six ``safe_mom`` statistics of the scalar series ``x``.

    Values go to ``feat`` and the matching ``{prefix}_{stat}`` names to
    ``cols``; both lists are mutated in place.
    """
    labels = ("mean", "std", "min", "max", "skew", "kurt")
    feat.extend(safe_mom(x))
    cols.extend(f"{prefix}_{label}" for label in labels)

####### ----------------- log parser ----------------- ########
NUM_RE = re.compile(r"[-+]?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?")

def _nums(l):
    return [float(t) for t in NUM_RE.findall(l)]

def _its(x):
    return int(round(x))

def parse_log(path):
    """Parse a pose log into SLOW records, FAST records and IMU samples.

    Each non-empty line is reduced to the list of numbers it contains. The
    record type is taken from an explicit ``SLOW`` / ``FAST`` / ``IMU`` keyword
    (case-insensitive) when the line has one, and otherwise from the number of
    numeric fields (17 -> SLOW, 23 -> FAST, 7 -> IMU). Lines that match no rule,
    or that carry fewer numbers than their record type needs (truncated lines,
    header lines that merely contain a keyword), are skipped so every stored
    vector has its full length.

    Args:
        path: Path of the log file. Undecodable bytes are ignored.

    Returns:
        ``(slows, fasts, imu)`` where
          * ``slows`` is a list of dicts with keys ``ts, ba, bg, pos, vel, quat``,
            sorted by ``ts``;
          * ``fasts`` is a list of dicts with keys ``ts, pos, vel, quat, b_acc,
            b_gyro, b_acc_prev, b_gyro_prev``, sorted by ``ts``;
          * ``imu`` is a ``defaultdict(list)`` mapping a timestamp to the list
            of ``{"w": ..., "a": ...}`` samples logged with that timestamp.
        ``ts`` is an ``int``; every other value is a float32 NumPy array.
    """
    # Minimum number of numeric fields a record of each type must provide.
    min_fields = {"SLOW": 17, "FAST": 23, "IMU": 7}
    # Fallback classification for lines without a keyword.
    tag_by_count = {17: "SLOW", 23: "FAST", 7: "IMU"}

    slows, fasts, imu = [], [], defaultdict(list)
    with open(path, "r", errors="ignore") as fh:
        for ln in fh:
            l = ln.strip()
            if not l:
                continue
            nums = _nums(l)
            if not nums:
                continue
            n = len(nums)
            up = l.upper()

            # An explicit keyword is more trustworthy than a field count, so it
            # is consulted first; the count only decides for untagged lines.
            tag = next((t for t in ("SLOW", "FAST", "IMU") if t in up), None)
            if tag is None:
                tag = tag_by_count.get(n)
            if tag is None or n < min_fields[tag]:
                continue

            if tag == "SLOW":
                slows.append(
                    dict(
                        ts=_its(nums[0]),
                        ba=np.float32(nums[1:4]),
                        bg=np.float32(nums[4:7]),
                        pos=np.float32(nums[7:10]),
                        vel=np.float32(nums[10:13]),
                        quat=np.float32(nums[13:17]),
                    )
                )
            elif tag == "FAST":
                fasts.append(
                    dict(
                        ts=_its(nums[0]),
                        pos=np.float32(nums[1:4]),
                        vel=np.float32(nums[4:7]),
                        quat=np.float32(nums[7:11]),
                        b_acc=np.float32(nums[11:14]),
                        b_gyro=np.float32(nums[14:17]),
                        b_acc_prev=np.float32(nums[17:20]),
                        b_gyro_prev=np.float32(nums[20:23]),
                    )
                )
            else:
                imu[_its(nums[0])].append(
                    dict(w=np.float32(nums[1:4]), a=np.float32(nums[4:7]))
                )

    slows.sort(key=lambda d: d["ts"])
    fasts.sort(key=lambda d: d["ts"])
    return slows, fasts, imu

#################### ----------------- feature builder -----------------##############
def build(slows, fasts, imu) -> Tuple[pd.DataFrame, List[str]]:
    """Build one feature row per window between consecutive SLOW records.

    For every pair of consecutive SLOW records ``(prev, cur)`` the FAST records
    with ``prev.ts <= ts < cur.ts`` form a window. Windows without FAST records
    are skipped. Each row holds, in order:

      * ``win_n_fast`` and ``win_dur_s`` (FAST count, ``(cur - prev) / 1e9``);
      * moment statistics of FAST-minus-SLOW position and velocity, per
        component and of the per-sample Euclidean norm;
      * moment statistics of the FAST/SLOW quaternion angle;
      * per-component moment statistics of the four FAST bias vectors;
      * moment statistics of the IMU ``w`` / ``a`` samples whose timestamp
        equals a FAST timestamp in the window (zeros when there are none);
      * the raw ``pos``, ``vel``, ``ba``, ``bg`` of the current SLOW record.

    Args:
        slows: SLOW records as returned by ``parse_log`` (sorted by ``ts``).
        fasts: FAST records as returned by ``parse_log``.
        imu: Mapping from timestamp to a list of IMU sample dicts.

    Returns:
        ``(df, cols)``: the feature frame and its column names, or
        ``(pd.DataFrame(), [])`` when no window produced a row.
    """
    fts = np.array([f["ts"] for f in fasts], np.int64)
    rows = []
    cols_final = None
    for prev_rec, s in zip(slows, slows[1:]):
        prev, cur = prev_rec["ts"], s["ts"]
        idx = np.nonzero((fts >= prev) & (fts < cur))[0]
        if idx.size == 0:
            continue
        fw = [fasts[j] for j in idx]
        feat, cols = [], []

        # NOTE: the second positional argument of np.linalg.norm is `ord`, not
        # `axis`; `axis=1` is required to get one Euclidean norm per sample.
        dpos = np.stack([f["pos"] - s["pos"] for f in fw])
        add_comp("dpos", dpos, feat, cols)
        add_scal("dpos_norm", np.linalg.norm(dpos, axis=1), feat, cols)

        dvel = np.stack([f["vel"] - s["vel"] for f in fw])
        add_comp("dvel", dvel, feat, cols)
        add_scal("dvel_norm", np.linalg.norm(dvel, axis=1), feat, cols)

        add_scal("q_angle_deg",
                 [quat_angle_deg(f["quat"], s["quat"]) for f in fw],
                 feat, cols)

        for key in ("b_acc", "b_gyro", "b_acc_prev", "b_gyro_prev"):
            # Report malformed vectors before np.stack fails on them.
            for pos_in_win, f in enumerate(fw):
                shape = np.shape(f[key])
                if shape != (3,):
                    print(f"[DEBUG] At index {pos_in_win}, key='{key}', bad shape: {shape}, value: {f[key]}")
            add_comp(key, np.stack([f[key] for f in fw]), feat, cols)

        # Collect the IMU samples that share a timestamp with a FAST record.
        iw, ia = [], []
        for f in fw:
            for sm in imu.get(int(f["ts"]), []):
                iw.append(sm["w"])
                ia.append(sm["a"])
        w_arr = np.stack(iw) if iw else np.zeros((0, 3))
        a_arr = np.stack(ia) if ia else np.zeros((0, 3))
        add_comp("imu_w", w_arr, feat, cols)
        add_scal("imu_w_norm", np.linalg.norm(w_arr, axis=1), feat, cols)
        add_comp("imu_a", a_arr, feat, cols)
        add_scal("imu_a_norm", np.linalg.norm(a_arr, axis=1), feat, cols)

        for k in ("pos", "vel", "ba", "bg"):
            for j, v in enumerate(s[k]):
                feat.append(float(v))
                cols.append(f"slow_{k}_{j}")

        # Window metadata goes first: win_n_fast, then win_dur_s.
        feat.insert(0, (cur - prev) / 1e9)
        cols.insert(0, "win_dur_s")
        feat.insert(0, float(len(idx)))
        cols.insert(0, "win_n_fast")
        rows.append(feat)

        if cols_final is None:
            cols_final = cols  # every row shares the layout of the first one

    if not rows or cols_final is None:
        return pd.DataFrame(), []  # Nothing to build

    df = pd.DataFrame(rows, columns=cols_final)
    return df, cols_final

# ----------------- AE -----------------
'''
class AE(nn.Module):
    def __init__(self, d, lat = LATENT_DIM, drop = DROPOUT):
        super().__init__()
        self.enc = nn.Sequential(
            nn.Linear(d, 256), nn.ReLU(),
            nn.Dropout(drop),
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, lat),
        )
        self.dec = nn.Sequential(
            nn.Linear(lat, 128), nn.ReLU(),
            nn.Linear(128, 256), nn.ReLU(),
            nn.Dropout(drop),
            nn.Linear(256, d),
        )

    def forward(self, x):
        return self.dec(self.enc(x))
'''
'''
class AE(nn.Module):
    def __init__(self, d, lat = LATENT_DIM, drop = DROPOUT):
        super().__init__()
        self.enc = nn.Sequential(
            nn.Linear(d, 128), nn.ReLU(),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64,  lat))
        self.dec = nn.Sequential(
            nn.Linear(lat, 64), nn.ReLU(),
            nn.Linear(64, 128), nn.ReLU(),
            nn.Linear(128, d))
    def forward(self, x):
        return self.dec(self.enc(x))
'''

def init_weights(m):
    """Re-initialise ``nn.Linear`` modules; leave every other module untouched.

    Weights are drawn with Kaiming-uniform initialisation (``leaky_relu``
    gain) and the bias, when the layer has one, is zeroed. Intended to be
    passed to ``Module.apply``.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.kaiming_uniform_(m.weight, nonlinearity='leaky_relu')
    if m.bias is not None:
        nn.init.zeros_(m.bias)




import torch.nn as nn

class AE(nn.Module):
    """Fully connected autoencoder: ``d -> 256 -> 128 -> 64 -> lat -> 64 -> 128 -> 256 -> d``.

    Encoder hidden layers are Linear + BatchNorm1d + LeakyReLU(0.01); decoder
    hidden layers are Linear + LeakyReLU(0.01). The decoder ends in ``Tanh``.

    NOTE(review): ``Tanh`` bounds every reconstructed value to (-1, 1);
    confirm the inputs fed to this model stay inside that range.

    Args:
        d: Width of the input (and of the reconstruction).
        lat: Width of the latent bottleneck.
    """

    def __init__(self, d, lat=LATENT_DIM):
        super().__init__()
        hidden = (256, 128, 64)

        # Encoder: widths shrink through `hidden`, then project to the latent.
        enc_layers = []
        width_in = d
        for width_out in hidden:
            enc_layers.extend((
                nn.Linear(width_in, width_out),
                nn.BatchNorm1d(width_out),
                nn.LeakyReLU(0.01),
            ))
            width_in = width_out
        enc_layers.append(nn.Linear(width_in, lat))
        self.enc = nn.Sequential(*enc_layers)

        # Decoder: mirror of the encoder without batch normalisation.
        dec_layers = []
        width_in = lat
        for width_out in reversed(hidden):
            dec_layers.extend((nn.Linear(width_in, width_out), nn.LeakyReLU(0.01)))
            width_in = width_out
        dec_layers.extend((nn.Linear(width_in, d), nn.Tanh()))
        self.dec = nn.Sequential(*dec_layers)

    def forward(self, x):
        """Encode ``x`` to the latent space and return its reconstruction."""
        latent = self.enc(x)
        return self.dec(latent)






def _run_epoch(model, loader, crit, opt=None):
    """Run one pass over ``loader`` and return the sample-weighted mean loss.

    With an optimizer the model is put in train mode and updated after every
    batch; without one the pass is a pure evaluation with autograd disabled.
    """
    training = opt is not None
    model.train(training)
    total, seen = 0.0, 0
    with torch.set_grad_enabled(training):
        for (xb,) in loader:
            if training and xb.size(0) < 2:
                # BatchNorm1d cannot compute batch statistics from one sample.
                continue
            loss = crit(model(xb), xb)
            if training:
                opt.zero_grad()
                loss.backward()
                opt.step()
            total += loss.item() * xb.size(0)
            seen += xb.size(0)
    return total / max(seen, 1)


def train_multi_logs(log_paths: List[str], n_train: int = 90):
    """Train the autoencoder on clean logs and save every inference artifact.

    Pipeline: parse and featurise each log, drop feature families excluded by
    name, add Gaussian jitter to the training rows, standard-scale and clip to
    [-6, 6], reduce with whitened PCA, train ``AE`` with early stopping on the
    validation loss, then derive anomaly thresholds from the validation
    reconstruction error.

    Files written to ``ae_parameters_multi/``: ``split_manifest.json``,
    ``scaler.pkl``, ``pca.pkl``, ``cols.json``, ``thresholds.json``, ``ae.pt``.
    Two diagnostic matplotlib figures are shown along the way.

    Args:
        log_paths: Log file paths. They are used in sorted order; the list
            passed in is not modified.
        n_train: Number of (sorted) logs used for training; the remainder is
            the validation set.

    Returns:
        Dict with ``mse`` (per-row validation reconstruction MSE), the
        thresholds ``thr_med``, ``thr_p98``, ``thr_p99``, ``thr_mad`` and
        ``df`` (validation feature frame with a fresh index).

    Raises:
        ValueError: If either split is empty or yields no feature rows.
    """
    ordered = sorted(log_paths)
    train_logs = ordered[:n_train]
    val_logs = ordered[n_train:]
    if not train_logs or not val_logs:
        raise ValueError(
            f"Need logs on both sides of the split: got {len(train_logs)} training "
            f"and {len(val_logs)} validation logs (n_train={n_train})."
        )

    # Save log split manifest
    out_dir = pathlib.Path("ae_parameters_multi")
    out_dir.mkdir(exist_ok=True)
    with open(out_dir / "split_manifest.json", "w") as fh:
        json.dump({"train_logs": train_logs, "val_logs": val_logs}, fh, indent=2)

    def extract_df(log_list):
        """Concatenate the feature frames of all logs that produced rows."""
        dfs = []
        for path in log_list:
            slows, fasts, imu = parse_log(path)
            df, _ = build(slows, fasts, imu)
            if not df.empty:
                dfs.append(df)
        if not dfs:
            raise ValueError("None of the given logs produced any feature window.")
        return pd.concat(dfs, ignore_index=True)

    df_train = extract_df(train_logs)
    df_val = extract_df(val_logs)

    ###### Drop unstable features (tested to see that these don't generalize well)
    cols_filtered = [c for c in df_train.columns if not any(k in c for k in (
        "skew", "kurt", "slow_pos", "slow_vel", "win_n_fast", "win_dur_s"))]

    def apply_jitter(X):
        """Add zero-mean Gaussian noise (std JITTER_STD) to every entry."""
        return X + np.random.normal(0, JITTER_STD, X.shape).astype(np.float32)

    X_train = apply_jitter(df_train[cols_filtered].values.astype(np.float32))
    X_val = df_val[cols_filtered].values.astype(np.float32)

    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_val_scaled = scaler.transform(X_val)

    # Clip extreme z-scores so single outliers cannot dominate PCA / the loss.
    X_train_scaled = np.clip(X_train_scaled, -6.0, 6.0)
    X_val_scaled = np.clip(X_val_scaled, -6.0, 6.0)

    #### PCA
    pca_full = PCA().fit(X_train_scaled)
    cum_var = np.cumsum(pca_full.explained_variance_ratio_)
    k = int(min(
        np.searchsorted(cum_var, PCA_VAR) + 1,
        X_train_scaled.shape[1],
        int(PCA_MAX_FRAC * len(X_train_scaled)),
    ))
    k = max(k, 1)  # the sample-count cap can round down to zero on tiny sets

    print(f"\n PCA Feature Reduction:")
    print(f"Original feature count : {X_train_scaled.shape[1]}")
    print(f"Reduced feature count  : {int(k)} (explaining {PCA_VAR*100:.1f}% variance)")

    plt.plot(cum_var)
    plt.axhline(PCA_VAR, color="r", linestyle="--", label=f"{PCA_VAR:.2%} target")
    plt.xlabel("Number of PCA components")
    plt.ylabel("Cumulative explained variance")
    plt.grid(True)
    plt.legend()
    plt.title("Explained Variance vs PCA Components")
    plt.show()

    pca = PCA(n_components=k, whiten=True)  # whitening ON: unit-variance components
    Z_train = pca.fit_transform(X_train_scaled)
    Z_val = pca.transform(X_val_scaled)

    # Pin float32 so the tensors always match the model's parameter dtype.
    Z_train_tensor = torch.tensor(Z_train, dtype=torch.float32)
    Z_val_tensor = torch.tensor(Z_val, dtype=torch.float32)
    tr_dl = DataLoader(TensorDataset(Z_train_tensor), batch_size=BATCH, shuffle=True)
    va_dl = DataLoader(TensorDataset(Z_val_tensor), batch_size=BATCH)

    ############ Building AE model
    ae = AE(int(k))
    ae.apply(init_weights)
    opt = torch.optim.AdamW(ae.parameters(), lr=LR, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS_MAX)
    crit = nn.SmoothL1Loss(beta=0.5)

    ###### Diagnostic (checking values) before training
    xb_sample = next(iter(tr_dl))[0]
    ae.eval()
    with torch.no_grad():
        out_sample = ae(xb_sample)
    print("Input mean/std :", xb_sample.mean().item(), xb_sample.std().item())
    print("Output mean/std:", out_sample.mean().item(), out_sample.std().item())
    print("First row MSE  :", torch.mean((xb_sample[0] - out_sample[0]) ** 2).item())
    print("Reconstruction delta:", torch.abs(out_sample - xb_sample).mean().item())
    plt.plot(xb_sample[0].cpu().numpy(), label='Input')
    plt.plot(out_sample[0].cpu().numpy(), label='Reconstructed')
    plt.legend()
    plt.show()

    best, bad = float("inf"), 0
    best_state = None
    train_losses, val_losses = [], []

    for ep in range(1, EPOCHS_MAX + 1):
        tr_loss = _run_epoch(ae, tr_dl, crit, opt)   # backward + optimizer step
        va_loss = _run_epoch(ae, va_dl, crit)        # evaluation only
        train_losses.append(tr_loss)
        val_losses.append(va_loss)
        print(f"ep{ep:03d} tr={tr_loss:.4f} va={va_loss:.4f} bad={bad}/{PATIENCE}")
        if va_loss < best - 1e-5:
            best, bad = va_loss, 0
            # state_dict() returns live references; clone to freeze a snapshot.
            best_state = {name: t.detach().clone() for name, t in ae.state_dict().items()}
        else:
            bad += 1
            if bad >= PATIENCE:
                break
        scheduler.step()  ###### update learning rate

    if best_state is not None:
        ae.load_state_dict(best_state)
    else:
        print("[WARN] validation loss never improved; keeping the final weights")
    ae.eval()
    with torch.no_grad():
        mse_val = torch.mean((ae(Z_val_tensor) - Z_val_tensor) ** 2, dim=1).numpy()

    ##### Thresholds
    thr_med = float(np.median(mse_val))
    thr_p98 = float(np.percentile(mse_val, 98))
    thr_p99 = float(np.percentile(mse_val, 99))
    thr_mad = float(thr_med + 3 * robust.mad(mse_val))

    print("\n[Thresholds from validation set]")
    print("Median:", thr_med)
    print("p98   :", thr_p98)
    print("p99   :", thr_p99)
    print("MAD   :", thr_mad)

    #### Save parameters
    with open(out_dir / "scaler.pkl", "wb") as fh:
        pickle.dump(scaler, fh)
    with open(out_dir / "pca.pkl", "wb") as fh:
        pickle.dump(pca, fh)
    with open(out_dir / "cols.json", "w") as fh:
        json.dump(cols_filtered, fh)
    with open(out_dir / "thresholds.json", "w") as fh:
        json.dump(dict(thr_mad=thr_mad, thr_p98=thr_p98, thr_p99=thr_p99, thr_med=thr_med), fh)
    torch.jit.script(ae.eval()).save(str(out_dir / "ae.pt"))
    print("\n✅ Saved model and thresholds to", out_dir)

    return {
        "mse": mse_val,
        "thr_med": thr_med,
        "thr_p98": thr_p98,
        "thr_p99": thr_p99,
        "thr_mad": thr_mad,
        "df": df_val.reset_index(drop=True)
    }

# Script entry point: train on the 100 clean logs, then inspect the validation
# reconstruction error with histograms, anomaly counts and a top-10 listing.
if __name__ == "__main__":
    # Paths are generated for log indices 001..100 under clean_logs/.
    log_paths = [f"clean_logs/normal_pose_append_log_{i:03d}.txt" for i in range(1, 101)]
    art = train_multi_logs(log_paths)

    # Attach MSE to dataframe
    df_new = art["df"]
    mse = art["mse"]
    df_new["mse"] = mse

    # Plot histogram of the validation MSE with every threshold overlaid
    plt.hist(mse, bins=100, color='skyblue', edgecolor='black')
    plt.axvline(art["thr_med"], color='green', linestyle='--', label='median')
    plt.axvline(art["thr_p98"], color='red', linestyle='--', label='p98')
    plt.axvline(art["thr_p99"], color='blue', linestyle='--', label='p99')
    plt.axvline(art["thr_mad"], color='yellow', linestyle='--', label='mad')
    plt.legend()
    plt.title("Reconstruction Error Histogram")
    plt.xlabel("MSE")
    plt.ylabel("Log Frequency")
    plt.yscale("log")
    plt.xlim(0, 2)  # rows with MSE above 2 fall outside the visible range
    plt.grid(True)
    plt.show()

    ##### Detect anomalies: row indices whose MSE exceeds each threshold
    anomalies_p98 = np.where(mse > art["thr_p98"])[0]
    anomalies_p99 = np.where(mse > art["thr_p99"])[0]
    anomalies_mad = np.where(mse > art["thr_mad"])[0]

    print(f"\n[Anomalies]")
    print(f"p98 threshold: {len(anomalies_p98)} anomalies")
    print(f"p99 threshold: {len(anomalies_p99)} anomalies")
    print(f"MAD threshold: {len(anomalies_mad)} anomalies")

    ##### Hybrid labeling
    # np.select takes the first matching condition: above p98 -> "hard",
    # otherwise above the MAD threshold -> "soft", otherwise "normal".
    df_new["hybrid_anomaly_level"] = np.select(
        [
            df_new["mse"] > art["thr_p98"],
            df_new["mse"] > art["thr_mad"]
        ],
        ["hard", "soft"],
        default="normal"
    )

    ##### Count and show
    print("\n[Hybrid Anomaly Classification]")
    hybrid_counts = df_new["hybrid_anomaly_level"].value_counts()
    for level in ["normal", "soft", "hard"]:
        count = hybrid_counts.get(level, 0)
        print(f"  {level:>6}: {count} / {len(df_new)} ({100*count/len(df_new):.2f}%)")

    ##### Plot with seaborn: MSE on a log-scaled axis, coloured by label
    sns.histplot(data=df_new, x="mse", hue="hybrid_anomaly_level", bins=100, log_scale=True)
    plt.axvline(art["thr_p98"], color="red", linestyle="--", label="p98")
    plt.axvline(art["thr_mad"], color="yellow", linestyle="--", label="mad")
    plt.legend()
    plt.title("MSE Histogram by Hybrid Anomaly Level")
    plt.xlabel("MSE")
    plt.ylabel("Log Count")
    plt.show()

    ##### Top anomalies: the ten rows with the largest MSE plus the first five columns
    print("\nTop 10 anomalies by MSE:")
    top_anomalies = df_new.sort_values(by="mse", ascending=False).head(10)
    print(top_anomalies[["mse"] + df_new.columns[:5].tolist()])



In [None]:
######## --- Load & preprocess log --- #########
# Inference cell: featurise one log, push it through the saved scaler, PCA and
# autoencoder, then label each window with the thresholds saved at training.
log_path = "/home/jarvis/pose_append_log.csv"  #### Replace as needed
slows, fasts, imu = parse_log(log_path)
df_eval, _ = build(slows, fasts, imu)

import pickle
import json

##### Load scaler, PCA, and column list
# NOTE(security): pickle.load can execute arbitrary code; only load artifacts
# written by your own training run.
with open("ae_parameters_multi/scaler.pkl", "rb") as f:
    scaler = pickle.load(f)

with open("ae_parameters_multi/pca.pkl", "rb") as f:
    pca = pickle.load(f)

with open("ae_parameters_multi/cols.json", "r") as f:
    cols_filtered = json.load(f)

ae = torch.jit.load("ae_parameters_multi/ae.pt")
ae.eval()


###### Ensure valid samples
if df_eval.empty:
    print("⚠️ No usable windows in the evaluation log.")
else:
    # Use the column list saved at training time (not one recomputed from
    # df_eval) so feature order and membership match what the scaler and PCA
    # were fitted on.
    missing = [c for c in cols_filtered if c not in df_eval.columns]
    if missing:
        raise KeyError(f"Evaluation features lack columns used in training: {missing}")

    #### Filter & scale
    X_eval = df_eval[cols_filtered].values.astype(np.float32)
    X_eval_scaled = scaler.transform(X_eval)
    X_eval_scaled = np.clip(X_eval_scaled, -6.0, 6.0)

    #### Apply PCA
    Z_eval = pca.transform(X_eval_scaled)

    #### Do Inference
    Z_eval_tensor = torch.tensor(Z_eval, dtype=torch.float32)
    with torch.no_grad():
        Z_recon = ae(Z_eval_tensor)
        mse_eval = torch.mean((Z_eval_tensor - Z_recon) ** 2, dim=1).numpy()

    ##### Classify anomalies using saved thresholds
    with open("ae_parameters_multi/thresholds.json", "r") as f:
        thr = json.load(f)

    df_eval["mse"] = mse_eval
    # First matching condition wins: > p98 is "hard", else > MAD is "soft".
    df_eval["hybrid_anomaly_level"] = np.select(
        [
            df_eval["mse"] > thr["thr_p98"],
            df_eval["mse"] > thr["thr_mad"]
        ],
        ["hard", "soft"],
        default="normal"
    )

    ##### Results summary ####
    counts = df_eval["hybrid_anomaly_level"].value_counts()
    total = len(df_eval)
    print(f"\n✅ Inference done on {log_path}")
    for level in ["normal", "soft", "hard"]:
        count = counts.get(level, 0)
        print(f"{level:>6}: {count:5d} / {total} ({100*count/total:.2f}%)")

    #### Plot MSE histogram
    import matplotlib.pyplot as plt
    import seaborn as sns

    sns.histplot(data=df_eval, x="mse", hue="hybrid_anomaly_level", bins=100, log_scale=True)
    plt.axvline(thr["thr_p98"], color="red", linestyle="--", label="p98")
    plt.axvline(thr["thr_mad"], color="yellow", linestyle="--", label="mad")
    plt.legend()
    plt.title("MSE Histogram (Eval Log)")
    plt.xlabel("MSE")
    plt.ylabel("Log Count")
    plt.grid(True)
    plt.show()

    #### Preview top anomalies ####
    top_anomalies = df_eval.sort_values(by="mse", ascending=False).head(10)
    display(top_anomalies[["mse"] + df_eval.columns[:5].tolist()])
