In [None]:
import os, re, json, warnings
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras import layers, models, callbacks
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import LeaveOneGroupOut

import matplotlib.pyplot as plt
import seaborn as sns
# Apply matplotlib's seaborn-style dark grid to every figure produced below.
plt.style.use("seaborn-v0_8-darkgrid")
# NOTE(review): this silences *all* warnings for the whole process, including
# numerical and deprecation warnings from NumPy / TensorFlow. Consider
# narrowing the filter to the specific categories that are known to be noise.
warnings.filterwarnings("ignore")

# Input directory scanned for *.csv recordings, and output directory that
# receives checkpoints, figures, metadata and the exported model.
DATA_DIR = "./Data"
OUTPUT_DIR = "./Outputs2"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Windowing and training hyper-parameters.
WINDOW_SIZE = 120          # rows per window (2 s at 60 Hz)
STRIDE = 1                 # hop between window starts; 1 => rows 1-120, 2-121, ...
EPOCHS = 30
BATCH_SIZE = 64

REQUIRE_CONSTANT_LABEL = True  # drop any window whose rows do not all carry the same label
USE_NORMALIZATION = True       # True = fixed per-channel z-normalization baked into the model

# Column names looked up in each CSV after header names are stripped and lower-cased.
FEATURE_COLS = ["ax", "ay", "az", "gx", "gy", "gz"]
LABEL_COL = "handedness"
TIME_COL = "timestamp"


# Maps a lower-cased label token (from the CSV or the file name) to its canonical class name.
ACCEPTED_LABELS = {"left": "LEFT", "right": "RIGHT", "both": "BOTH"}

def _standardize_cols(df: pd.DataFrame) -> pd.DataFrame:
    return df.rename(columns={c: c.strip().lower() for c in df.columns})

def _label_from_filename(stem: str):
    s = stem.lower()
    for tok in ["left", "right", "both"]:
        if re.search(rf"(^|[_\-\s]){tok}([_\-\s]|$)", s):
            return ACCEPTED_LABELS[tok]
    return None

def load_imu_csv(path: Path) -> pd.DataFrame:
    """Read one IMU recording and canonicalize its label column.

    Steps:
      1. Read the CSV and normalize header names via ``_standardize_cols``.
      2. Verify that the timestamp, label and all feature columns are present.
      3. Map each label (stringified, stripped, lower-cased) through
         ``ACCEPTED_LABELS``; labels with no mapping keep their original value.
      4. If the file name implies a label and the file holds exactly one
         distinct label that differs from it, print a warning (nothing is
         modified).

    Raises:
        ValueError: if any required column is missing.
    """
    frame = _standardize_cols(pd.read_csv(path))

    expected = {TIME_COL, LABEL_COL, *FEATURE_COLS}
    absent = expected.difference(frame.columns)
    if absent:
        raise ValueError(f"{path.name}: missing columns: {sorted(absent)}")

    # Unmapped labels fall back to the value that was read from the file.
    raw_labels = frame[LABEL_COL]
    canonical = raw_labels.astype(str).str.strip().str.lower().map(ACCEPTED_LABELS)
    frame[LABEL_COL] = canonical.fillna(raw_labels)

    name_label = _label_from_filename(path.stem)
    if name_label is None:
        return frame

    # Only single-label files can be compared against the file name.
    distinct = frame[LABEL_COL].unique()
    if len(distinct) == 1 and distinct[0] != name_label:
        print(f"[WARN] Label mismatch {path.name}: in-file={distinct[0]} vs name={name_label}")
    return frame

def make_stride_windows(df: pd.DataFrame,
                        file_id: str,
                        window_size: int,
                        stride: int,
                        require_constant_label: bool = True):
    """Slice one recording into fixed-length windows.

    Args:
        df: frame containing ``FEATURE_COLS`` and ``LABEL_COL``.
        file_id: identifier repeated once per produced window (the group id).
        window_size: number of consecutive rows per window.
        stride: step between consecutive window starts.
        require_constant_label: when True, windows whose rows do not all share
            the first row's label are skipped; when False, the most frequent
            label in the window is used.

    Returns:
        ``(X, y, groups)`` where X is float32 of shape
        ``(num_windows, window_size, len(FEATURE_COLS))``, y is an array with
        one label per window, and groups is a list of ``file_id`` of the same
        length. When no window is produced X has zero rows, y is an empty
        array and groups is an empty list.
    """
    n_channels = len(FEATURE_COLS)

    def _no_windows():
        # Same empty triple for "recording too short" and "every window rejected".
        return np.empty((0, window_size, n_channels), dtype=np.float32), np.array([]), []

    total_rows = len(df)
    if total_rows < window_size:
        return _no_windows()

    signal = df[FEATURE_COLS].to_numpy(dtype=np.float32)
    row_labels = df[LABEL_COL].to_numpy()

    windows, window_labels = [], []
    for lo in range(0, total_rows - window_size + 1, stride):
        hi = lo + window_size
        span = row_labels[lo:hi]
        if require_constant_label:
            is_constant = (span == span[0]).all()
            if not is_constant:
                continue
            chosen = span[0]
        else:
            # Majority vote over the labels inside the window.
            candidates, freq = np.unique(span, return_counts=True)
            chosen = candidates[np.argmax(freq)]
        windows.append(signal[lo:hi])
        window_labels.append(chosen)

    if len(windows) == 0:
        return _no_windows()
    return np.stack(windows, axis=0), np.array(window_labels), [file_id] * len(windows)

def load_and_window_all(data_dir: str, window_size: int, stride: int,
                        require_constant_label: bool = True):
    """Load every ``*.csv`` in *data_dir* and window each one.

    Files are processed in sorted path order. Files that yield no windows
    are skipped.

    Returns:
        ``(X, y, groups)``: the concatenated windows, one label per window,
        and an array holding the source file stem of each window.

    Raises:
        FileNotFoundError: if *data_dir* contains no CSV file.
    """
    csv_paths = sorted(Path(data_dir).glob("*.csv"))
    if len(csv_paths) == 0:
        raise FileNotFoundError(f"No CSV files found in {data_dir}")

    x_parts, y_parts, group_ids = [], [], []
    for csv_path in tqdm(csv_paths, desc="Loading CSVs"):
        frame = load_imu_csv(csv_path)
        windows, win_labels, win_groups = make_stride_windows(
            frame, csv_path.stem, window_size, stride, require_constant_label
        )
        # Nothing usable in this file: too short, or every window was rejected.
        if windows.size == 0:
            continue
        x_parts.append(windows)
        y_parts.append(win_labels)
        group_ids += win_groups

    if x_parts:
        X = np.concatenate(x_parts, axis=0)
    else:
        X = np.empty((0, window_size, len(FEATURE_COLS)), dtype=np.float32)
    if y_parts:
        y = np.concatenate(y_parts, axis=0)
    else:
        y = np.array([])
    return X, y, np.array(group_ids)


def build_cnn_with_fixed_norm_and_calibration(input_shape, num_classes, mean, std):
    """Build and compile the 1-D CNN classifier with a fixed input z-normalization.

    Architecture: fixed per-channel normalization -> learnable 1x1 convolution
    ("calibration") -> three Conv1D/BatchNorm/ReLU stages (64, 64, 128 filters,
    kernel size 5, one max-pool after the second stage) -> global average
    pooling -> Dense(128) -> Dropout(0.3) -> softmax over *num_classes*.

    Args:
        input_shape: ``(T, C)`` shape of one window, e.g. ``(120, 6)``.
        num_classes: number of output classes.
        mean: per-channel mean, any array-like reshapeable to ``(C,)``.
        std: per-channel standard deviation, same shape as *mean*; values
            below 1e-8 are raised to 1e-8.

    Returns:
        A compiled ``Sequential`` model (Adam, sparse categorical
        cross-entropy, accuracy metric).
    """
    n_channels = input_shape[-1]
    mean = np.asarray(mean, dtype="float32").reshape((n_channels,))
    std = np.asarray(std, dtype="float32").reshape((n_channels,))
    std = np.maximum(std, 1e-8)
    var = (std ** 2).astype("float32")

    # Fixed normalization. The statistics are handed to the constructor, which
    # is the documented way to create a non-adapted Normalization layer.
    # Building the layer in "adapt" mode and pushing values in afterwards with
    # set_weights() is fragile: the layer caches broadcast copies of its
    # statistics when it is built, so later weight changes are not guaranteed
    # to reach the forward pass, and the number of weights (2 vs 3) varies
    # between Keras versions.
    norm = layers.Normalization(
        axis=-1, mean=mean, variance=var, name="fixed_normalization"
    )
    norm.trainable = False

    # Learnable per-channel calibration / light channel mixing applied right
    # after the fixed normalization.
    calib = layers.Conv1D(
        filters=n_channels, kernel_size=1, padding='same',
        use_bias=True, name="calibration_conv1x1"
    )

    model = models.Sequential([
        layers.Input(shape=input_shape),
        norm,
        calib,

        layers.Conv1D(64, 5, padding='same', activation=None),
        layers.BatchNormalization(), layers.Activation('relu'),

        layers.Conv1D(64, 5, padding='same', activation=None),
        layers.BatchNormalization(), layers.Activation('relu'),
        layers.MaxPooling1D(2),

        layers.Conv1D(128, 5, padding='same', activation=None),
        layers.BatchNormalization(), layers.Activation('relu'),

        layers.GlobalAveragePooling1D(),
        layers.Dense(128, activation='relu'),
        layers.Dropout(0.3),
        layers.Dense(num_classes, activation='softmax'),
    ])
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model




# Build the windowed dataset from every CSV in DATA_DIR: X_raw holds the
# windows, y_str one label per window, groups the source file stem per window.
X_raw, y_str, groups = load_and_window_all(
    DATA_DIR, WINDOW_SIZE, STRIDE, require_constant_label=REQUIRE_CONSTANT_LABEL
)
print(f"Raw windows: X={X_raw.shape}, y={y_str.shape}, unique labels={np.unique(y_str)}")

# Encode labels as integer class ids (the models are compiled with a sparse
# categorical loss); encoder.classes_ keeps the id -> name mapping.
encoder = LabelEncoder()
y = encoder.fit_transform(y_str)
num_classes = len(encoder.classes_)
print("Label mapping:", dict(zip(encoder.classes_, range(num_classes))))


# Cross-validation grouped by source file: each fold holds out all windows of one group.
logo = LeaveOneGroupOut()
# Per-fold test accuracies and Keras History objects, appended in fold order.
accuracies, histories = [], []
fold = 1  # 1-based fold number used in checkpoint file names and log lines

def fit_norm_from_train(X_train):
    """Compute per-channel normalization statistics from training windows.

    Args:
        X_train: array whose last axis indexes channels; all leading axes
            (windows, time steps) are pooled together.

    Returns:
        ``(mean, std)`` as float32 arrays of shape ``(channels,)``. Channels
        whose standard deviation is below 1e-8 get a std of 1.0 so that
        normalization never divides by (almost) zero.
    """
    flat = X_train.reshape(-1, X_train.shape[-1])
    n_channels = flat.shape[1]

    # Accumulate in float64. With stride-1 windows the number of pooled rows
    # is very large, and reducing float32 data with a float32 accumulator
    # loses accuracy (see the notes in the numpy.mean documentation).
    mean = flat.mean(axis=0, dtype=np.float64)
    # One channel at a time, so the float64 intermediates stay a fraction of
    # the size of the full training array.
    std = np.array(
        [flat[:, ch].std(dtype=np.float64) for ch in range(n_channels)],
        dtype=np.float64,
    )
    std = np.where(std < 1e-8, 1.0, std)
    return mean.astype(np.float32), std.astype(np.float32)

# Cross-validation grouped by source file: every fold trains a fresh model on
# all other groups and tests on the held-out one.
# The split generator is consumed lazily (tqdm is told the fold count
# separately), so the index arrays of all folds are never held in memory at once.
for train_idx, test_idx in tqdm(
    logo.split(X_raw, y, groups),
    total=logo.get_n_splits(groups=groups),
    desc="Cross-validation folds",
):
    X_train_raw, X_test_raw = X_raw[train_idx], X_raw[test_idx]
    y_train, y_test = y[train_idx], y[test_idx]

    # Normalization statistics come from this fold's training windows only,
    # so nothing from the held-out file leaks into the model.
    if USE_NORMALIZATION:
        mean, std = fit_norm_from_train(X_train_raw)
    else:
        # Identity normalization: zero mean, unit std.
        mean = np.zeros(X_train_raw.shape[-1], dtype=np.float32)
        std = np.ones(X_train_raw.shape[-1], dtype=np.float32)

    model = build_cnn_with_fixed_norm_and_calibration(
        input_shape=X_train_raw.shape[1:],
        num_classes=num_classes,
        mean=mean,
        std=std
    )

    # The checkpoint keeps the epoch with the best validation accuracy on
    # disk; early stopping restores the best-val_loss weights in memory.
    ckpt_path = f"{OUTPUT_DIR}/fold_{fold}_best.keras"
    cbs = [
        callbacks.ModelCheckpoint(ckpt_path, monitor='val_accuracy', save_best_only=True, verbose=0),
        callbacks.EarlyStopping(monitor='val_loss', patience=6, restore_best_weights=True)
    ]

    # NOTE(review): Keras takes the validation_split slice from the tail of the
    # arrays before shuffling, so validation here is the last 20% of the
    # training windows in file order - confirm that this is the intended split.
    hist = model.fit(
        X_train_raw, y_train,
        validation_split=0.2,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        verbose=0,
        callbacks=cbs
    )
    histories.append(hist)

    # Test accuracy is measured with the in-memory model, not the checkpoint file.
    loss, acc = model.evaluate(X_test_raw, y_test, verbose=0)
    accuracies.append(acc)
    print(f"❖ Fold {fold}: Test Accuracy = {acc:.3f}")
    fold += 1

# Summary across folds: mean ± population standard deviation of test accuracy.
print(f"\n Mean Accuracy = {np.mean(accuracies):.3f} ± {np.std(accuracies):.3f}")


# Bar chart of the test accuracy of each cross-validation fold.
plt.figure(figsize=(8,5))
plt.bar(range(1, len(accuracies)+1), accuracies, color='skyblue')
plt.title("Per-Fold Accuracy (Group CV by File)")
plt.xlabel("Fold")
plt.ylabel("Accuracy")
plt.ylim(0, 1)
# Value label above each bar, capped at 0.98 so it stays inside the 0-1 axis.
for i, acc in enumerate(accuracies):
    plt.text(i+1, min(acc+0.02, 0.98), f"{acc:.2f}", ha='center')
plt.savefig(f"{OUTPUT_DIR}/fold_accuracy.png", dpi=300)
plt.show()

# Training curves of the last fold only (accuracy and loss share one axis).
# Indexing histories[-1] requires at least one completed fold.
last_hist = histories[-1].history
plt.figure(figsize=(10,5))
plt.plot(last_hist['accuracy'], label='Train Acc', linewidth=2)
plt.plot(last_hist['val_accuracy'], label='Val Acc', linewidth=2)
plt.plot(last_hist['loss'], label='Train Loss', linestyle='--')
plt.plot(last_hist['val_loss'], label='Val Loss', linestyle='--')
plt.title("Training Progress (Last Fold)")
plt.xlabel("Epoch")
plt.ylabel("Value")
plt.legend()
plt.savefig(f"{OUTPUT_DIR}/training_curves_last_fold.png", dpi=300)
plt.show()


# Train the final (deployable) model on every window of the dataset.
# Normalization statistics are computed over the whole dataset with the same
# helper the cross-validation folds use, so the two code paths cannot drift
# apart.
if USE_NORMALIZATION:
    final_mean, final_std = fit_norm_from_train(X_raw)
else:
    # Identity normalization: zero mean, unit std.
    final_mean = np.zeros(X_raw.shape[-1], dtype=np.float32)
    final_std = np.ones(X_raw.shape[-1], dtype=np.float32)

final_model = build_cnn_with_fixed_norm_and_calibration(
    input_shape=(WINDOW_SIZE, len(FEATURE_COLS)),
    num_classes=num_classes,
    mean=final_mean,
    std=final_std
)

# The checkpoint keeps the epoch with the best validation accuracy on disk;
# early stopping restores the best-val_loss weights into final_model.
cb_final = [
    callbacks.ModelCheckpoint(f"{OUTPUT_DIR}/final_best.keras", monitor='val_accuracy', save_best_only=True),
    callbacks.EarlyStopping(monitor='val_loss', patience=6, restore_best_weights=True)
]

# NOTE(review): Keras takes the validation_split slice from the tail of the
# arrays before shuffling, so validation here is the last 20% of the windows
# in file order - confirm that this slice contains every class.
final_history = final_model.fit(
    X_raw, y,
    validation_split=0.2,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    verbose=1,
    callbacks=cb_final
)


# Training curves of the final model (accuracy and loss share one axis).
plt.figure(figsize=(10,5))
plt.plot(final_history.history['accuracy'], label='Train Acc', linewidth=2)
plt.plot(final_history.history['val_accuracy'], label='Val Acc', linewidth=2)
plt.plot(final_history.history['loss'], label='Train Loss', linestyle='--')
plt.plot(final_history.history['val_loss'], label='Val Loss', linestyle='--')
plt.title("Final Model Training")
plt.xlabel("Epoch")
plt.ylabel("Value")
plt.legend()
plt.savefig(f"{OUTPUT_DIR}/final_training_curve.png", dpi=300)
plt.show()

# NOTE(review): predictions are made on X_raw, the same windows final_model
# was fitted on, so this confusion matrix and report describe fit on seen
# data, not held-out performance (the per-fold accuracies above do that).
y_pred = np.argmax(final_model.predict(X_raw, verbose=0), axis=1)
cm = confusion_matrix(y, y_pred)
plt.figure(figsize=(6,5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=encoder.classes_, yticklabels=encoder.classes_)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix (Final Model)")
plt.savefig(f"{OUTPUT_DIR}/confusion_matrix.png", dpi=300)
plt.show()

# Per-class precision / recall / F1 for the same predictions.
print(classification_report(y, y_pred, target_names=encoder.classes_))


# Write the metadata a consumer of the exported model needs: class order,
# feature column order, window length, and the normalization statistics that
# were passed to the final model.
with open(f"{OUTPUT_DIR}/metadata.json", "w") as f:
    json.dump({
        "label_classes": encoder.classes_.tolist(),
        "feature_order": FEATURE_COLS,
        "window_size": WINDOW_SIZE,
        "stride_for_training": STRIDE,
        "normalization_used": USE_NORMALIZATION,
        "mean": final_mean.tolist(),
        "std":  final_std.tolist()
    }, f, indent=2)

# Convert the in-memory final model to a TFLite flatbuffer and write it to disk.
tflite_path = f"{OUTPUT_DIR}/imu_cnn_model.tflite"
converter = tf.lite.TFLiteConverter.from_keras_model(final_model)
tflite_model = converter.convert()
with open(tflite_path, "wb") as f:
    f.write(tflite_model)

print(f"Exported TFLite model to {tflite_path}")
print(f"Metadata saved to {OUTPUT_DIR}/metadata.json")
