In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import loadmat
from scipy.signal import butter, filtfilt, detrend, windows, welch
from scipy.stats import skew, kurtosis
import os

# Configuration: acquisition parameters used by the spectral analysis below
FS = 12_000  # Sampling frequency in Hz (samples per second)

## 1. Load Data
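We load sample recordings from the Case Western Reserve University (CWRU) bearing dataset. Make sure the `.mat` files are in the `../data/` directory (this path is resolved relative to the notebook's working directory).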

In [None]:
def load_data(filename, data_dir="../data"):
    """Load the ``*_DE_time`` series from a CWRU ``.mat`` file as a 1-D array.

    Parameters
    ----------
    filename : str
        Name of the ``.mat`` file inside *data_dir*.
    data_dir : str, optional
        Directory holding the data files. Defaults to ``"../data"``, which is
        resolved relative to the current working directory.

    Returns
    -------
    numpy.ndarray or None
        The first variable whose name ends with ``"_DE_time"`` (presumably the
        drive-end accelerometer channel in CWRU naming), flattened to 1-D.
        ``None`` if the file does not exist or contains no such variable; a
        message is printed in both cases.
    """
    path = os.path.join(data_dir, filename)
    # isfile (not exists) so a directory with a matching name is reported
    # instead of crashing inside loadmat.
    if not os.path.isfile(path):
        print(f"File not found: {path}")
        return None

    data = loadmat(path)
    # loadmat also returns metadata entries (__header__, __version__,
    # __globals__); none of them end with "_DE_time", so they are skipped.
    for key, value in data.items():
        if key.endswith("_DE_time"):
            return value.flatten()

    print(f"No '*_DE_time' variable found in: {path}")
    return None

# Load the sample recordings; each entry is None when its file (or its
# "_DE_time" variable) could not be found.
normal_sig, inner_sig, ball_sig, outer_sig = (
    load_data(fname)
    for fname in ("97.mat", "inner_race.mat", "ball_fault.mat", "outer_race.mat")
)

# Keep only the first N samples of each recording. Note that these truncated
# arrays feed the later preprocessing, spectrum and feature cells too, not
# just the plots.
N = 2048
normal_sig, inner_sig, ball_sig, outer_sig = (
    sig if sig is None else sig[:N]
    for sig in (normal_sig, inner_sig, ball_sig, outer_sig)
)

## 2. Preprocessing
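Raw signals often contain high-frequency noise, DC offsets and slow drifts. We apply:
- **Low-pass filter:** a 4th-order Butterworth filter with a cutoff of 0.2 × Nyquist (1.2 kHz at the 12 kHz sampling rate), applied forward and backward (`filtfilt`) so that it introduces no phase shift.
- **Detrending:** remove the best-fit linear trend, which also removes any DC offset.
- **Normalization:** scale to zero mean and unit variance.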

In [None]:
def preprocess_signal(signal, order=4, cutoff=0.2):
    """Low-pass filter, detrend and z-score normalize a 1-D signal.

    Parameters
    ----------
    signal : array_like
        1-D input signal. It must be longer than ``3 * (order + 1)`` samples,
        the default padding length ``filtfilt`` uses for this filter.
    order : int, optional
        Butterworth filter order (default 4).
    cutoff : float, optional
        Cutoff frequency as a fraction of the Nyquist frequency, strictly
        between 0 and 1 (default 0.2, i.e. 1.2 kHz at a 12 kHz sampling rate).

    Returns
    -------
    numpy.ndarray
        The processed signal with zero mean and unit standard deviation, or
        an all-zero array if the detrended signal is exactly flat (which would
        otherwise produce 0/0 = NaN).
    """
    # 1. Zero-phase (forward-backward) low-pass filtering: no time shift.
    b, a = butter(order, cutoff, btype="low")
    filtered = filtfilt(b, a, signal)

    # 2. Remove the least-squares linear trend (this also removes the mean).
    detrended = detrend(filtered)

    # 3. Z-score normalization, guarding the flat-signal case.
    std = np.std(detrended)
    if std == 0:
        return np.zeros_like(detrended)
    return (detrended - np.mean(detrended)) / std

# Apply preprocessing to the two signals analysed below; a recording that
# failed to load stays None.
normal_proc, inner_proc = (
    None if raw is None else preprocess_signal(raw)
    for raw in (normal_sig, inner_sig)
)

## 3. Visualization (Time Domain)
Compare the raw and preprocessed versions of the normal-condition signal.

In [None]:
# Compare the raw and preprocessed normal-condition signals sample by sample.
# Skip the figure when the recording is unavailable (None), matching the
# None guards used by the other cells; plotting None would fail here.
if normal_sig is not None and normal_proc is not None:
    plt.figure(figsize=(12, 6))

    plt.subplot(2, 1, 1)
    plt.plot(normal_sig)
    plt.title("Raw Normal Signal")
    plt.xlabel("Sample index")
    plt.ylabel("Amplitude")
    plt.grid(True)

    plt.subplot(2, 1, 2)
    plt.plot(normal_proc)
    plt.title("Preprocessed Normal Signal")
    plt.xlabel("Sample index")
    plt.ylabel("Normalized amplitude")
    plt.grid(True)

    plt.tight_layout()
    plt.show()

## 4. Frequency Domain Analysis (FFT & PSD)
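Faults are often easier to detect in the frequency domain. Below we plot the FFT magnitude spectrum and the Welch power spectral density (PSD) of the preprocessed inner-race fault signal.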

In [None]:
def plot_spectrum(signal, title, fs=None, nperseg=1024):
    """Plot the FFT magnitude spectrum and the Welch PSD of a signal side by side.

    Parameters
    ----------
    signal : array_like
        1-D time series.
    title : str
        Prefix used in both subplot titles.
    fs : float, optional
        Sampling frequency in Hz. ``None`` (default) uses the module-level ``FS``.
    nperseg : int, optional
        Welch segment length (default 1024). It is clamped to ``len(signal)``;
        scipy would otherwise warn and apply the same clamp itself.
    """
    if fs is None:
        fs = FS
    signal = np.asarray(signal)

    # One-sided FFT magnitude; unnormalized, so amplitudes scale with len(signal).
    fft_vals = np.abs(np.fft.rfft(signal))
    freqs = np.fft.rfftfreq(len(signal), d=1 / fs)

    # Welch PSD: averages periodograms of overlapping segments to reduce variance.
    f, psd = welch(signal, fs=fs, nperseg=min(nperseg, len(signal)))

    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(freqs, fft_vals)
    plt.title(f"{title} - FFT")
    plt.xlabel("Frequency (Hz)")
    plt.ylabel("Magnitude")
    plt.grid(True)

    plt.subplot(1, 2, 2)
    # Log y-axis so both strong and weak spectral components stay visible.
    plt.semilogy(f, psd)
    plt.title(f"{title} - PSD")
    plt.xlabel("Frequency (Hz)")
    plt.ylabel("PSD (units^2/Hz)")
    plt.grid(True)

    plt.tight_layout()
    plt.show()

# Spectra are only drawn for the inner-race fault recording, and only if it loaded.
if inner_proc is not None:
    plot_spectrum(inner_proc, title="Inner Race Fault")

## 5. Feature Extraction
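We extract time-domain statistical features (RMS, kurtosis, skewness, peak-to-peak) and frequency-domain features (dominant frequency and its amplitude) as inputs for a machine-learning model.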

In [None]:
def extract_features(signal, fs=None):
    """Compute time- and frequency-domain features of a 1-D signal.

    Parameters
    ----------
    signal : array_like
        Non-empty 1-D time series.
    fs : float, optional
        Sampling frequency in Hz. ``None`` (default) uses the module-level ``FS``.

    Returns
    -------
    dict
        ``RMS``; ``Kurtosis`` (scipy's default Fisher/excess kurtosis, 0 for a
        Gaussian); ``Skewness``; ``Peak_to_Peak``; ``Dominant_Freq`` (Hz, the
        strongest non-DC bin); ``Dominant_Amp`` (unnormalized |rFFT| magnitude,
        which grows with ``len(signal)``).

    Raises
    ------
    ValueError
        If *signal* is empty.
    """
    # Float cast avoids integer overflow in signal**2 for integer inputs.
    signal = np.asarray(signal, dtype=float)
    if signal.size == 0:
        raise ValueError("extract_features() requires a non-empty signal")
    if fs is None:
        fs = FS

    feats = {}
    # Time domain
    feats["RMS"] = np.sqrt(np.mean(signal**2))
    feats["Kurtosis"] = kurtosis(signal)
    feats["Skewness"] = skew(signal)
    feats["Peak_to_Peak"] = np.ptp(signal)

    # Frequency domain
    fft_vals = np.abs(np.fft.rfft(signal))
    freqs = np.fft.rfftfreq(len(signal), d=1 / fs)
    # Skip the DC bin (0 Hz) when other bins exist; otherwise any mean offset
    # in an un-normalized signal is reported as the "dominant frequency".
    start = 1 if fft_vals.size > 1 else 0
    idx_max = start + np.argmax(fft_vals[start:])
    feats["Dominant_Freq"] = freqs[idx_max]
    feats["Dominant_Amp"] = fft_vals[idx_max]

    return feats

# Print the feature dictionary for the inner-race fault recording, if it loaded.
if inner_proc is not None:
    print("Features for Inner Race Fault:")
    print(extract_features(signal=inner_proc))