In [None]:
import sys,os
sys.path.append('..')

In [None]:
from src.main_CWRU import evaluate_split, get_model, main_split

In [None]:
import pandas as pd
import numpy as np
import random
import torch

def set_random_seed(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global RNG, PyTorch's CPU RNG and
    the RNGs of all CUDA devices. It also puts cuDNN in deterministic mode with
    benchmarking off, trading some speed for run-to-run reproducibility.

    Args:
        seed: integer seed applied to every generator (default 42).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

import scipy.io as sio
import pickle
import os

# One row per recording: (.mat file name, class label, name of the signal variable
# to read from that file). Both lookup tables below are derived from this single
# table, so a file's label and variable name cannot drift apart.
_CWRU_SOURCES = (
    ("Normal_0.mat", 0, "X097_DE_time"),
    ("IR014_0.mat", 1, "X169_DE_time"),
    ("OR014@6_0.mat", 2, "X197_DE_time"),
    ("B014_0.mat", 3, "X185_DE_time"),
)

# .mat file name -> integer class label (insertion order = row order above).
mat_files = {fname: cls for fname, cls, _ in _CWRU_SOURCES}

# .mat file name -> variable name holding the time series inside that file.
mat_keys = {fname: var for fname, _, var in _CWRU_SOURCES}

def extract_and_segment_mat(file_path, key, label, segment_length=1500):
    """Load one variable from a MATLAB file and cut it into equal-length windows.

    The array stored under ``key`` is squeezed and then split row-wise into
    non-overlapping windows. Trailing samples that do not fill a whole window
    are dropped.

    Args:
        file_path: path handed to ``scipy.io.loadmat``.
        key: name of the variable to read from the file.
        label: class label assigned to every window from this file.
        segment_length: number of samples per window (default 1500).

    Returns:
        ``(segments, labels)``: ``segments`` has shape ``(n_windows, segment_length)``
        and ``labels`` has shape ``(n_windows,)``, filled with ``label``.
    """
    samples = sio.loadmat(file_path)[key].squeeze()
    n_windows = len(samples) // segment_length
    usable = n_windows * segment_length
    windows = samples[:usable].reshape(n_windows, segment_length)
    return windows, np.full(n_windows, label)

# Build the combined dataset: segment every recording listed in mat_files,
# shuffle, add a channel axis, and cache the result to a pickle.
# The .mat files are read from paths relative to the current working directory.

# Seed all RNGs before shuffling. Otherwise np.random.permutation below is
# unseeded, and the shuffle order and the saved pickle change on every
# Restart & Run All. 42 matches the seed passed to main_split in the cells below.
set_random_seed(42)

X_list = []
y_list = []

for file, label in mat_files.items():
    segments, labels = extract_and_segment_mat(file, mat_keys[file], label)
    X_list.append(segments)
    y_list.append(labels)

X_raw = np.vstack(X_list)
y_raw = np.hstack(y_list)
assert X_raw.shape[0] == y_raw.shape[0], "segment/label count mismatch"

shuffle_idx = np.random.permutation(len(y_raw))
X_raw = X_raw[shuffle_idx]
y_raw = y_raw[shuffle_idx]

# Insert a singleton axis: (n_segments, segment_length) -> (n_segments, 1, segment_length).
X_all_combined = X_raw[:, np.newaxis, :]
y_all_combined = y_raw

save_path = "cwru_dataset_combined.pkl"
with open(save_path, "wb") as f:
    pickle.dump({
        "X_all_combined": X_all_combined,
        "y_all_combined": y_all_combined
    }, f)

print(f"saved: {os.path.abspath(save_path)}")
print(f"X_all_combined shape: {X_all_combined.shape}")
print(f"y_all_combined shape: {y_all_combined.shape}")
print("label:", np.bincount(y_all_combined))

In [None]:
# Run main_split (imported from src.main_CWRU) for model_type "ptfm" on the combined dataset built above.
main_split(X_all_combined, y_all_combined, seed=42, model_type="ptfm")

In [None]:
# Run main_split (imported from src.main_CWRU) for model_type "fft_mlp" on the combined dataset built above.
main_split(X_all_combined, y_all_combined, seed=42, model_type="fft_mlp")

In [None]:
# Run main_split (imported from src.main_CWRU) for model_type "cnn" on the combined dataset built above.
main_split(X_all_combined, y_all_combined, seed=42, model_type="cnn")

In [None]:
# Run main_split (imported from src.main_CWRU) for model_type "vit_tiny" on the combined dataset built above.
main_split(X_all_combined, y_all_combined, seed=42, model_type="vit_tiny")

In [None]:
# Run main_split (imported from src.main_CWRU) for model_type "deepbdc" on the combined dataset built above.
main_split(X_all_combined, y_all_combined, seed=42, model_type="deepbdc")