# Vibration Fault Detection — ML Training

Train tree-based classifiers on 35 firmware-matched features extracted from
raw 3-axis acceleration segments.  Designed to run on **Google Colab** (no GPU needed).

In [None]:
# ── Cell 1: Setup ─────────────────────────────────────────────────────────
# Install dependencies (uncomment on Colab)
# !pip install -q numpy scipy pandas scikit-learn lightgbm xgboost matplotlib seaborn datasets

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split, StratifiedKFold, cross_val_score
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    ConfusionMatrixDisplay,
    roc_curve,
    auc,
)
from sklearn.preprocessing import label_binarize
from sklearn.ensemble import RandomForestClassifier
import lightgbm as lgb
import xgboost as xgb
import pickle, os, warnings

# Notebook-wide presentation defaults: hide every warning and draw all
# seaborn figures on the white-grid theme.
warnings.filterwarnings(action='ignore')
sns.set_style(style='whitegrid')
print('Setup complete.')

In [None]:
# ── Cell 2: Configuration ─────────────────────────────────────────────────

# ── Data source: choose ONE ──
# Option A: HuggingFace dataset (for Colab)
DATA_SOURCE = 'huggingface'
HF_DATASET = 'adyady/bearing-fault-dataset'
SIGNAL_FIELD = 'high_data'  # signal column to read: 'high_data' or 'low_data'

# Option B: Local JSON files (uncomment both lines to switch)
# DATA_SOURCE = 'local'
# DATA_DIR = 'data/raw'

# Cached feature table; feature extraction is skipped when this file exists
FEATURE_CSV = 'features.csv'

# Raw fault_category → training label.  The three bearing sub-categories
# collapse into one 'Bearing Fault' class; every other category keeps a
# class of its own.
_BEARING_CATEGORIES = ('bearing_inner', 'bearing_outer', 'bearing_rolling')
CLASS_MAP = {
    'healthy': 'Healthy',
    **dict.fromkeys(_BEARING_CATEGORIES, 'Bearing Fault'),
    'electrical': 'Electrical Fault',
    'flow_cavitation': 'Flow/Cavitation',
    'unbalance': 'Unbalance',
    'misalignment': 'Misalignment',
    'gear_fault': 'Gear Fault',
}

# 35 feature columns (canonical order from feature_extraction.py):
#   temp, five per-axis groups (x, y, z within each), maxCf, then the
#   peak-ratio groups ordered signal prefix → band → axis.
_AXES = ('x', 'y', 'z')
_PER_AXIS_SUFFIXES = ('RMS', 'VRMS', 'EnvRMS', 'KU', 'P2P')
_SIGNAL_PREFIXES = ('acc', 'vel')
_BANDS = ('Low', 'Mid', 'High')

FEATURE_COLS = (
    ['temp']
    + [f'{axis}{suffix}' for suffix in _PER_AXIS_SUFFIXES for axis in _AXES]
    + ['maxCf']
    + [
        f'{prefix}{band}PeakRatio{axis.upper()}'
        for prefix in _SIGNAL_PREFIXES
        for band in _BANDS
        for axis in _AXES
    ]
)

RANDOM_STATE = 42
print(f'Config: {len(FEATURE_COLS)} features, source={DATA_SOURCE}')

In [None]:
# ── Cell 3: Load Data ─────────────────────────────────────────────────────
# Builds `df` with one feature row per segment.  Sources, in priority order:
#   1. FEATURE_CSV, when it already exists (no extraction is run);
#   2. the HuggingFace dataset, when DATA_SOURCE == 'huggingface';
#   3. otherwise local JSON files under DATA_DIR.
# Afterwards fault_category is mapped to a training label via CLASS_MAP and
# rows whose category has no mapping are dropped.

import sys, json

MAX_SKIP_MESSAGES = 5   # per-segment failures printed before going quiet
PROGRESS_EVERY = 2000   # segments between progress lines

if os.path.exists(FEATURE_CSV):
    print(f'Loading pre-computed features from {FEATURE_CSV}')
    df = pd.read_csv(FEATURE_CSV)

elif DATA_SOURCE == 'huggingface':
    # Imported only on this path so a cached FEATURE_CSV can be used without
    # feature_extraction.py being available.
    try:
        from feature_extraction import extract_features
    except ModuleNotFoundError as err:
        if err.name != 'feature_extraction':
            raise  # the module was found, but one of its own imports is missing
        raise ModuleNotFoundError(
            'feature_extraction.py not found. Paste it into Colab or upload it '
            '(or clone your repo: !git clone <your-repo-url>), then re-run this cell.',
            name=err.name,
        ) from err

    from datasets import load_dataset

    print(f'Loading HuggingFace dataset: {HF_DATASET} ...')
    ds = load_dataset(HF_DATASET, split='train')
    hf_df = ds.to_pandas()
    print(f'  Loaded {len(hf_df)} rows (each row = 1 axis of 1 segment)')
    print(f'  Columns: {list(hf_df.columns)}')
    print(f'  Axes: {hf_df["axis"].value_counts().to_dict()}')
    print(f'  Fault categories: {hf_df["fault_category"].value_counts().to_dict()}')

    # ── Group per-axis rows into segments ──
    # Each unique file_name may have rows for axis x, y, z
    grouped = hf_df.groupby('file_name')
    n_segments = len(grouped)
    print(f'\n  Unique segments: {n_segments}')
    print(f'  Extracting features (using "{SIGNAL_FIELD}" column) ...')

    rows = []
    n_missing_x = 0   # segments with no 'x' axis row
    n_failed = 0      # segments whose extraction raised
    for i, (seg_name, seg_rows) in enumerate(grouped, start=1):
        try:
            # axis name → signal array; a later row for the same axis wins
            axes_data = {
                ax: np.asarray(signal, dtype=np.float64)
                for ax, signal in zip(seg_rows['axis'], seg_rows[SIGNAL_FIELD])
            }
            x = axes_data.get('x')
            y = axes_data.get('y')
            z = axes_data.get('z')

            if x is None:
                # The x axis is mandatory; y and z are passed on as None when absent
                n_missing_x += 1
            else:
                first_row = seg_rows.iloc[0]

                # Prefer the resampled rate, fall back to the original one
                fs = float(first_row.get('target_sample_rate') or first_row.get('original_sample_rate'))

                # RPM is optional: top-level 'rpm', else operating_conditions.rpm
                rpm = None
                meta_str = first_row.get('metadata_json', '{}')
                if isinstance(meta_str, str) and meta_str:
                    try:
                        meta = json.loads(meta_str)
                        rpm = meta.get('rpm') or meta.get('operating_conditions', {}).get('rpm')
                    except (json.JSONDecodeError, TypeError):
                        pass
                if rpm is not None:
                    rpm = float(rpm)

                feats = extract_features(x, y, z, fs=fs, rpm=rpm)
                feats['filename'] = seg_name
                feats['fault_category'] = first_row.get('fault_category', '')
                feats['fault_type'] = first_row.get('fault_type', '')
                feats['dataset'] = first_row.get('source_dataset', '')
                feats['sample_rate_hz'] = fs
                rows.append(feats)

        except Exception as e:
            # Best effort per segment: one bad segment must not abort the run,
            # but every failure is counted so the loss is visible afterwards.
            n_failed += 1
            if n_failed <= MAX_SKIP_MESSAGES:
                print(f'    SKIP {seg_name}: {e}')

        if i % PROGRESS_EVERY == 0:
            print(f'    Processed {i}/{n_segments} segments')

    n_skipped = n_missing_x + n_failed
    if n_skipped:
        print(f'    Skipped {n_skipped}/{n_segments} segments '
              f'({n_missing_x} without an x axis, {n_failed} failed extraction)')

    if not rows:
        # Fail before writing FEATURE_CSV: an empty cache file would be picked
        # up by the next run instead of re-extracting.
        raise RuntimeError(
            f'No features could be extracted from {HF_DATASET} '
            f'({n_segments} segments, all skipped). '
            'Check SIGNAL_FIELD and the SKIP messages above.'
        )

    df = pd.DataFrame(rows)
    df.to_csv(FEATURE_CSV, index=False)
    print(f'\n  Extracted {len(df)} segments → saved to {FEATURE_CSV}')

else:  # local JSON files
    if 'DATA_DIR' not in globals():
        raise NameError(
            "DATA_DIR is not defined — uncomment the 'Option B' lines in the "
            'configuration cell before using a local data source.'
        )
    print(f'Running feature extraction on {DATA_DIR} ...')
    from feature_extraction import extract_all
    df = extract_all(DATA_DIR, output_csv=FEATURE_CSV)

# ── Map fault_category → class label ──
df['label'] = df['fault_category'].map(CLASS_MAP)
unknown = df['label'].isna().sum()
if unknown > 0:
    unmapped = df.loc[df['label'].isna(), 'fault_category'].unique()
    print(f'WARNING: {unknown} rows have unmapped fault_category: {unmapped}')
    print('         These rows will be dropped.  Update CLASS_MAP to include them.')
    df = df.dropna(subset=['label'])

print(f'\nDataset: {len(df)} segments, {df["label"].nunique()} classes')
print(df['label'].value_counts())

In [None]:
# ── Cell 4: EDA ───────────────────────────────────────────────────────────


def _boxplot_by_class(feature, target_ax, title):
    """Draw one horizontal box plot of `feature` per class label on `target_ax`."""
    sns.boxplot(data=df, y='label', x=feature, order=order, ax=target_ax, hue='label', legend=False)
    target_ax.set_title(title)


fig, axes = plt.subplots(1, 3, figsize=(20, 5))
ax_counts, ax_corr, ax_rms = axes

# 4a. Class distribution, most frequent class first
order = df['label'].value_counts().index
sns.countplot(data=df, y='label', order=order, ax=ax_counts, hue='label', legend=False)
ax_counts.set_title('Class Distribution')
ax_counts.set_xlabel('Count')

# 4b. Correlation heatmap over the configured features that exist in df
#     and hold at least one non-NaN value
feat_present = [
    col for col in FEATURE_COLS
    if col in df.columns and not df[col].isna().all()
]
corr = df[feat_present].corr()
sns.heatmap(
    corr,
    ax=ax_corr,
    cmap='coolwarm',
    center=0,
    xticklabels=False,
    yticklabels=False,
    cbar_kws={'shrink': 0.6},
)
ax_corr.set_title('Feature Correlation')

# 4c. xRMS spread per class
if 'xRMS' in df.columns:
    _boxplot_by_class('xRMS', ax_rms, 'xRMS by Class')

plt.tight_layout()
plt.show()

# 4d. One extra box-plot panel per key feature that is present and not all-NaN
key_features = [
    name for name in ('xVRMS', 'xKU', 'xP2P', 'maxCf')
    if name in df.columns and df[name].notna().any()
]
if key_features:
    n_panels = len(key_features)
    fig, axes2 = plt.subplots(1, n_panels, figsize=(5 * n_panels, 4))
    # A single panel comes back as a bare Axes rather than a sequence
    axes2 = [axes2] if n_panels == 1 else axes2
    for panel_ax, feature_name in zip(axes2, key_features):
        _boxplot_by_class(feature_name, panel_ax, feature_name)
    plt.tight_layout()
    plt.show()

In [None]:
# ── Cell 5: Preprocessing ─────────────────────────────────────────────────

# Configured features that are actually present in df, in FEATURE_COLS order
use_cols = [col for col in FEATURE_COLS if col in df.columns]
X = df.loc[:, use_cols].copy()
y = df.loc[:, 'label'].copy()

# Hold out 20% for testing, keeping class proportions equal in both parts
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.20,
    stratify=y,
    random_state=RANDOM_STATE,
)

# Mean over columns of each column's NaN fraction in the training part
nan_fraction = X_train.isna().mean().mean()

print(f'Features used: {len(use_cols)}')
print(f'Train: {len(X_train)}  |  Test: {len(X_test)}')
print(f'NaN fraction in train: {nan_fraction:.2%}')
print('\nNo scaling applied — tree models handle raw values and NaN natively.')

In [None]:
# ── Cell 6: Model Training ────────────────────────────────────────────────
# Trains LightGBM, XGBoost and RandomForest on the training split, scores each
# with 5-fold stratified CV accuracy, and keeps the best one as `best_model`.

from sklearn.base import BaseEstimator, ClassifierMixin, clone
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_sample_weight


class LabelEncodedClassifier(ClassifierMixin, BaseEstimator):
    """Let a classifier that needs integer targets train on string labels.

    XGBoost (>= 1.6) rejects targets that are not already encoded as
    0..n_classes-1.  This wrapper encodes the labels before fitting and
    decodes predictions afterwards, so it can be used interchangeably with
    the other models: `predict` returns the original labels and
    `predict_proba` columns follow `classes_` (sorted label order).

    Parameters
    ----------
    estimator : classifier
        Unfitted classifier; a fresh clone is fitted on every `fit` call.
    class_weight : None or 'balanced' or dict
        When set and no explicit `sample_weight` is given to `fit`, sample
        weights are derived from the labels being fitted.  Doing this inside
        `fit` means cross-validation weights come from each training fold and
        no fit parameters have to be routed through `cross_val_score`.

    NOTE: a pickled instance can only be loaded where this class definition
    is importable (e.g. re-run this cell before `pickle.load`).
    """

    def __init__(self, estimator, class_weight=None):
        self.estimator = estimator
        self.class_weight = class_weight

    def fit(self, X, y, sample_weight=None):
        """Fit a clone of `estimator` on integer-encoded `y`; returns self."""
        self.encoder_ = LabelEncoder().fit(y)
        self.classes_ = self.encoder_.classes_
        if sample_weight is None and self.class_weight is not None:
            sample_weight = compute_sample_weight(self.class_weight, y)
        self.estimator_ = clone(self.estimator)
        self.estimator_.fit(X, self.encoder_.transform(y), sample_weight=sample_weight)
        return self

    def predict(self, X):
        """Predict and map the integer codes back to the original labels."""
        return self.encoder_.inverse_transform(self.estimator_.predict(X))

    def predict_proba(self, X):
        """Class probabilities, one column per entry of `classes_`."""
        return self.estimator_.predict_proba(X)


cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=RANDOM_STATE)
class_labels = sorted(y.unique())

# ── 6a. LightGBM (primary) ──
lgb_model = lgb.LGBMClassifier(
    n_estimators=300,
    learning_rate=0.05,
    max_depth=7,
    num_leaves=63,
    class_weight='balanced',
    random_state=RANDOM_STATE,
    verbosity=-1,
)
lgb_cv = cross_val_score(lgb_model, X_train, y_train, cv=cv, scoring='accuracy')
lgb_model.fit(X_train, y_train)
print(f'LightGBM  5-fold CV accuracy: {lgb_cv.mean():.4f} ± {lgb_cv.std():.4f}')

# ── 6b. XGBoost (comparison) ──
# XGBClassifier has no class_weight option and needs integer targets; the
# wrapper supplies balanced sample weights and the label encoding.
xgb_model = LabelEncodedClassifier(
    xgb.XGBClassifier(
        n_estimators=300,
        learning_rate=0.05,
        max_depth=7,
        random_state=RANDOM_STATE,
        eval_metric='mlogloss',
        verbosity=0,
    ),
    class_weight='balanced',
)
xgb_cv = cross_val_score(xgb_model, X_train, y_train, cv=cv, scoring='accuracy')
xgb_model.fit(X_train, y_train)
print(f'XGBoost   5-fold CV accuracy: {xgb_cv.mean():.4f} ± {xgb_cv.std():.4f}')

# ── 6c. Random Forest (baseline) ──
rf_model = RandomForestClassifier(
    n_estimators=300,
    max_depth=20,
    class_weight='balanced',
    random_state=RANDOM_STATE,
    n_jobs=-1,
)
rf_cv = cross_val_score(rf_model, X_train, y_train, cv=cv, scoring='accuracy')
rf_model.fit(X_train, y_train)
print(f'RandomForest 5-fold CV accuracy: {rf_cv.mean():.4f} ± {rf_cv.std():.4f}')

# Pick best model by mean CV accuracy: name → (fitted model, mean CV accuracy)
results = {
    'LightGBM': (lgb_model, lgb_cv.mean()),
    'XGBoost': (xgb_model, xgb_cv.mean()),
    'RandomForest': (rf_model, rf_cv.mean()),
}
best_name = max(results, key=lambda k: results[k][1])
best_model = results[best_name][0]
print(f'\nBest model: {best_name} ({results[best_name][1]:.4f} CV accuracy)')

In [None]:
# ── Cell 7: Evaluation ────────────────────────────────────────────────────
# Reports held-out test metrics for every trained model, then plots the
# confusion matrix and one-vs-rest ROC curves of the selected best model.

for name, (model, _) in results.items():
    y_pred = model.predict(X_test)
    acc = (y_pred == y_test).mean()
    print(f'\n{"=" * 60}')
    print(f'{name}  —  Test Accuracy: {acc:.4f}')
    print(f'{"=" * 60}')
    print(classification_report(y_test, y_pred, zero_division=0))

# Confusion matrix for best model
y_pred_best = best_model.predict(X_test)
cm = confusion_matrix(y_test, y_pred_best, labels=class_labels)
fig, ax = plt.subplots(figsize=(8, 6))
ConfusionMatrixDisplay(cm, display_labels=class_labels).plot(ax=ax, cmap='Blues', colorbar=False)
ax.set_title(f'Confusion Matrix — {best_name}')
plt.xticks(rotation=30, ha='right')
plt.tight_layout()
plt.show()

# ROC curves (one-vs-rest)
if hasattr(best_model, 'predict_proba'):
    y_score = best_model.predict_proba(X_test)
    y_true = np.asarray(y_test)
    fig, ax = plt.subplots(figsize=(8, 6))
    for i, cls in enumerate(class_labels):
        if i >= y_score.shape[1]:
            break  # the model exposes fewer probability columns than labels
        # Build the one-vs-rest indicator directly instead of label_binarize:
        # for a two-class problem label_binarize returns a single column (the
        # indicator of the second class), which would be paired with the
        # wrong probability column.
        is_cls = y_true == cls
        if is_cls.all() or not is_cls.any():
            continue  # ROC is undefined without both positives and negatives
        fpr, tpr, _ = roc_curve(is_cls.astype(int), y_score[:, i])
        ax.plot(fpr, tpr, label=f'{cls} (AUC={auc(fpr, tpr):.3f})')
    ax.plot([0, 1], [0, 1], 'k--', alpha=0.3)  # chance diagonal
    ax.set_xlabel('FPR')
    ax.set_ylabel('TPR')
    ax.set_title(f'ROC Curves — {best_name}')
    ax.legend(loc='lower right', fontsize=8)
    plt.tight_layout()
    plt.show()

In [None]:
# ── Cell 8: Feature Importance ────────────────────────────────────────────

# LightGBM importances labelled by feature name, largest first
importance = lgb_model.feature_importances_
feat_imp = pd.Series(data=importance, index=use_cols).sort_values(ascending=False)

# Plot at most the 20 highest-ranked features
top_n = min(20, len(feat_imp))
top_features = feat_imp.iloc[:top_n]

fig, ax = plt.subplots(figsize=(8, 6))
top_features.plot(kind='barh', ax=ax)
ax.invert_yaxis()  # put the highest-ranked feature at the top
ax.set_title(f'Top-{top_n} Feature Importance (LightGBM)')
ax.set_xlabel('Importance (split count)')
plt.tight_layout()
plt.show()

print('\nTop-20 features:')
print(feat_imp.iloc[:20].to_string())

In [None]:
# ── Cell 9: Save Model ────────────────────────────────────────────────────

# Bundle the selected model with the metadata stored alongside it: feature
# order, class labels, the category mapping and its CV score.
best_cv_accuracy = results[best_name][1]
artifact = dict(
    model=best_model,
    model_name=best_name,
    feature_names=use_cols,
    class_labels=class_labels,
    class_map=CLASS_MAP,
    cv_accuracy=best_cv_accuracy,
)

model_path = 'best_model.pkl'
with open(model_path, mode='wb') as f:
    pickle.dump(artifact, f)

summary_lines = (
    f'Saved {best_name} to {model_path}',
    f'  Features : {len(use_cols)}',
    f'  Classes  : {class_labels}',
    f'  CV acc   : {best_cv_accuracy:.4f}',
)
print(*summary_lines, sep='\n')