In [1]:
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import random
import requests
from scipy.io import loadmat
import torch
import urllib.request

# Set seed

In [2]:
def set_seed(seed):
    """Seed every random number generator used in this notebook.

    The same ``seed`` is applied to PyTorch (the default generator and the
    CUDA generators), NumPy's global generator and Python's ``random`` module.

    Args:
        seed: Integer seed forwarded unchanged to each library.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)


set_seed(42)

# Download & extract the files

In [3]:
# Download links, keyed by the local file name (without extension).
# Each group consists of four files whose numeric IDs on the server are
# consecutive; the key suffix (_0 ... _3) is the offset from the group's
# first ID, e.g. 'normal_2' -> .../99.mat.
CWRU_links = {
    f'{group}_{offset}': f'https://engineering.case.edu/sites/default/files/{first_id + offset}.mat'
    for group, first_id in (
        ('normal', 97),

        ('12k_DE_IR007', 105),
        ('12k_DE_B007', 118),
        ('12k_DE_OR007@6', 130),

        ('12k_DE_IR014', 169),
        ('12k_DE_B014', 185),
        ('12k_DE_OR014@6', 197),

        ('12k_DE_IR021', 209),
        ('12k_DE_B021', 222),
        ('12k_DE_OR021@6', 234),
    )
    for offset in range(4)
}

def download_and_extract(file_name, url, folder_path, dtype, extract_function,
                         max_retries=5, retry_delay=1.0):
    """Download ``url`` into ``folder_path`` and hand the file to ``extract_function``.

    The file is stored as ``<folder_path>/<file_name><dtype>``. The download
    and the extraction are retried together, but only a bounded number of
    times, so a permanent failure (bad URL, full disk, ...) surfaces as an
    error instead of looping forever.

    Args:
        file_name: Local file name without extension.
        url: Address to download from.
        folder_path: Existing directory that receives the file.
        dtype: File extension including the leading dot (e.g. ``'.mat'``).
        extract_function: Callable invoked as ``extract_function(folder_path,
            file_name)`` after a successful download.
        max_retries: Maximum number of attempts (must be >= 1).
        retry_delay: Seconds to wait between two attempts.

    Raises:
        ValueError: If ``max_retries`` is smaller than 1.
        RuntimeError: If every attempt failed; the last underlying exception
            is attached as ``__cause__``.
    """
    import time  # local import: only needed for the pause between retries

    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")

    target_path = os.path.join(folder_path, f'{file_name}{dtype}')
    # Download to a temporary name first so an interrupted transfer never
    # leaves a truncated file under the final name (and never clobbers a
    # previously downloaded good copy).
    partial_path = target_path + '.part'

    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            print(f"Downloading {url}")
            urllib.request.urlretrieve(url, partial_path)
            os.replace(partial_path, target_path)
            print(f'Extracting {file_name}{dtype}')
            extract_function(folder_path, file_name)
            return  # Success: stop retrying
        except Exception as e:
            last_error = e
            print(f"Failed to download {url}: {e}")
            if os.path.exists(partial_path):
                os.remove(partial_path)
            if attempt < max_retries:
                time.sleep(retry_delay)

    raise RuntimeError(
        f"Giving up on {url} after {max_retries} attempts"
    ) from last_error

def extract_nothing(folder, file_name):
    """No-op extractor for downloads that need no unpacking.

    Accepts the same ``(folder, file_name)`` arguments as any other extract
    callback, ignores them and returns ``None``.
    """
    return None

# Download & Extract CWRU dataset
folder_path = os.path.join(os.getcwd(), 'CWRU')
os.makedirs(folder_path, exist_ok=True)
for file_name, url_link in CWRU_links.items():
    # The files are stored on disk as '<file_name>.mat', so the existence check
    # has to include the extension; without it the check never matches and
    # every file is downloaded again on each run.
    if not os.path.exists(os.path.join(folder_path, f'{file_name}.mat')):
        download_and_extract(file_name, url_link, folder_path, '.mat', extract_nothing)

Downloading https://engineering.case.edu/sites/default/files/97.mat
Extracting normal_0.mat
Downloading https://engineering.case.edu/sites/default/files/98.mat
Extracting normal_1.mat
Downloading https://engineering.case.edu/sites/default/files/99.mat
Extracting normal_2.mat
Downloading https://engineering.case.edu/sites/default/files/100.mat
Extracting normal_3.mat
Downloading https://engineering.case.edu/sites/default/files/105.mat
Extracting 12k_DE_IR007_0.mat
Downloading https://engineering.case.edu/sites/default/files/106.mat
Extracting 12k_DE_IR007_1.mat
Downloading https://engineering.case.edu/sites/default/files/107.mat
Extracting 12k_DE_IR007_2.mat
Downloading https://engineering.case.edu/sites/default/files/108.mat
Extracting 12k_DE_IR007_3.mat
Downloading https://engineering.case.edu/sites/default/files/118.mat
Extracting 12k_DE_B007_0.mat
Downloading https://engineering.case.edu/sites/default/files/119.mat
Extracting 12k_DE_B007_1.mat
Downloading https://engineering.case.ed

# Data Preprocessing

In [9]:
# source_lables_dict = {
#     0: ['normal_0'],
#     1: ['12k_DE_IR007_0'],
#     2: ['12k_DE_B007_0'],
#     3: ['12k_DE_OR007@6_0']
# }

# target_lables_dict = {
#     0: ['normal_1', 'normal_2', 'normal_3'],
#     1: ['12k_DE_IR007_1', '12k_DE_IR007_2', '12k_DE_IR007_3', '12k_DE_IR014_1', '12k_DE_IR014_2', '12k_DE_IR014_3', '12k_DE_IR021_1', '12k_DE_IR021_2', '12k_DE_IR021_3'],
#     2: ['12k_DE_B007_1', '12k_DE_B007_2', '12k_DE_B007_3', '12k_DE_B014_1', '12k_DE_B014_2', '12k_DE_B014_3', '12k_DE_B021_1', '12k_DE_B021_2', '12k_DE_B021_3'],
#     3: ['12k_DE_OR007@6_1', '12k_DE_OR007@6_2', '12k_DE_OR007@6_3', '12k_DE_OR014@6_1', '12k_DE_OR014@6_2', '12k_DE_OR014@6_3', '12k_DE_OR021@6_1', '12k_DE_OR021@6_2', '12k_DE_OR021@6_3']
# }


In [12]:
# Class index -> the four recordings (suffix _0 ... _3) that carry that label.
# The position of a group in the list below is its class index.
class_lables_dict = {
    label: [f'{group}_{suffix}' for suffix in range(4)]
    for label, group in enumerate([
        'normal',
        '12k_DE_IR007',
        '12k_DE_B007',
        '12k_DE_OR007@6',
        '12k_DE_IR014',
        '12k_DE_B014',
        '12k_DE_OR014@6',
        '12k_DE_IR021',
        '12k_DE_B021',
        '12k_DE_OR021@6',
    ])
}

# Directory that holds the downloaded .mat files.
folder_path = os.path.join(os.getcwd(), "CWRU")

def read_dict(mat_dict):
    """Load all files of a ``{label: [file names]}`` mapping into two tensors.

    Each entry is delegated to ``read_list``; the per-label results are then
    concatenated along the first dimension, in the mapping's iteration order.

    Returns:
        Tuple ``(x, y)`` of the concatenated sample and label tensors.
    """
    per_label = [read_list(files, label) for label, files in mat_dict.items()]

    samples = torch.cat([pair[0] for pair in per_label], dim=0)
    targets = torch.cat([pair[1] for pair in per_label], dim=0)

    return samples, targets

def read_list(file_list, label, window_size=1024, step=1024):
    """Load the given recordings and cut them into labelled windows.

    For every file the drive-end (``X<id>_DE_time``) and fan-end
    (``X<id>_FE_time``) variables are read from the .mat file, stacked as two
    channels and split into windows with ``sliding_window_subsample``.

    Args:
        file_list: File names (keys of ``CWRU_links``, without extension)
            located in ``folder_path``.
        label: Class label assigned to every window of these files.
        window_size: Number of time steps per window.
        step: Offset between the starts of two consecutive windows.

    Returns:
        Tuple ``(x, y)``: ``x`` is the concatenation of all windows along the
        first dimension (``(num_windows, 2, window_size)`` when both channels
        are 1-D after squeezing), ``y`` holds ``label`` once per window.
    """
    x, y = [], []
    for file_name in file_list:
        # No extension here: this relies on loadmat's default appendmat=True,
        # which adds '.mat' to the path when it is missing.
        file_path = os.path.join(folder_path, file_name)
        column_name = get_column_name(file_name) # Get the original file name (97.mat, 105.mat etc ..)
        data = loadmat(file_path)

        DE_Channel = data[f'X{column_name}_DE_time']
        FE_Channel = data[f'X{column_name}_FE_time']

        # Stack the two squeezed channels on a new leading axis; np.stack
        # requires both channels to have the same length.
        combined_channels = np.stack((DE_Channel.squeeze(), FE_Channel.squeeze()), axis=0)
        combined_tensor = torch.tensor(combined_channels)

        sample_tensor = sliding_window_subsample(combined_tensor, window_size=window_size, step=step)
        label_tensor = torch.full((sample_tensor.shape[0],), label)
        x.append(sample_tensor)
        y.append(label_tensor)

    x = torch.cat(x, dim=0)
    y = torch.cat(y, dim=0)

    return x, y

def get_column_name(file_name):
    """Return the numeric ID used in the .mat variable names of ``file_name``.

    The ID is the remote file's base name (the part of its ``CWRU_links`` URL
    after the last '/' and before the first '.'). IDs shorter than three
    characters get a single leading '0', e.g. '97' -> '097'.
    """
    remote_name = CWRU_links[file_name].rsplit('/', 1)[-1]
    file_id = remote_name.partition('.')[0]
    return file_id if len(file_id) >= 3 else '0' + file_id

def sliding_window_subsample(tensor, window_size=1024, step=1024):
    """Cut a multi-channel signal into windows along its last axis.

    For a 2-D input of shape ``(channels, length)`` the result has shape
    ``(num_windows, channels, window_size)``; trailing values that do not fill
    a complete window are dropped.

    Args:
        tensor: Signal tensor, expected as ``(channels, length)``.
        window_size: Number of time steps per window.
        step: Offset between the starts of two consecutive windows.
    """
    # (channels, length) -> (channels, 1, length)
    expanded = torch.unsqueeze(tensor, 1)
    # -> (channels, 1, num_windows, window_size)
    windows = expanded.unfold(2, window_size, step)
    # -> (1, channels, num_windows, window_size)
    windows = torch.transpose(windows, 0, 1)
    # -> (1, num_windows, channels, window_size)
    windows = torch.transpose(windows, 1, 2)
    # Drop the helper axis again.
    return torch.squeeze(windows, 0)

# Load every recording listed in class_lables_dict and window it into samples;
# x holds the windows, y the class label of each window.
x, y = read_dict(class_lables_dict)
print(x.shape, y.shape)

torch.Size([5927, 2, 1024]) torch.Size([5927])


In [14]:
def train_test_split(x, y, train_frac=0.6, val_frac=0.2, generator=None):
    """Randomly split samples and labels into training, validation and test sets.

    Args:
        x: Sample tensor; its first dimension indexes the samples.
        y: Label tensor with the same first dimension as ``x``.
        train_frac: Fraction of the samples used for training.
        val_frac: Fraction of the samples used for validation. Whatever is
            left after training and validation becomes the test set.
        generator: Optional ``torch.Generator`` for a split that does not
            depend on the global random state. ``None`` uses PyTorch's
            default generator.

    Returns:
        Tuple ``(train, val, test)``; each element is the dictionary produced
        by ``split_xy`` (``{"samples": ..., "labels": ...}``).

    Raises:
        ValueError: If a fraction lies outside ``[0, 1]`` or the two fractions
            sum to more than 1.
    """
    if not (0.0 <= train_frac <= 1.0 and 0.0 <= val_frac <= 1.0):
        raise ValueError("train_frac and val_frac must both lie in [0, 1]")
    if train_frac + val_frac > 1.0:
        raise ValueError("train_frac + val_frac must not exceed 1")

    dataset = torch.utils.data.TensorDataset(x, y) # Combine x and y to ensure both are split in the same way

    total_size = len(dataset)
    train_size = int(train_frac * total_size)
    val_size = int(val_frac * total_size)
    # The test set absorbs the rounding remainder so the sizes always add up.
    test_size = total_size - train_size - val_size

    # Only forward the generator when one was supplied, so the default call
    # behaves exactly like a plain random_split.
    split_kwargs = {} if generator is None else {"generator": generator}

    # Split the dataset
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size, test_size], **split_kwargs
    )

    # Split x and y to maintain consistency with other dataset
    train = split_xy(train_dataset)
    val = split_xy(val_dataset)
    test = split_xy(test_dataset)

    return train, val, test

# Split x and y to maintain consistency with other dataset
def split_xy(dataset):
    """Turn a dataset of ``(sample, label)`` pairs into a dict of two tensors.

    Iterates over ``dataset`` once, stacks all samples and all labels along a
    new first dimension, prints both shapes and returns
    ``{"samples": <stacked samples>, "labels": <stacked labels>}``.
    """
    pairs = list(dataset)
    # Convert the collected items to tensors
    samples = torch.stack([sample for sample, _ in pairs])
    targets = torch.stack([target for _, target in pairs])
    print(samples.shape, targets.shape)

    return {"samples": samples, "labels": targets}

# Randomly partition the windows: 60% training, 20% validation, remainder testing.
training, validation, testing = train_test_split(x, y)

torch.Size([3556, 2, 1024]) torch.Size([3556])
torch.Size([1185, 2, 1024]) torch.Size([1185])
torch.Size([1186, 2, 1024]) torch.Size([1186])


In [15]:
# Save the datasets: one .pt file per split inside the CWRU folder
for split_file, split_data in (
    ('train.pt', training),
    ('val.pt', validation),
    ('test.pt', testing),
):
    torch.save(split_data, os.path.join(os.getcwd(), 'CWRU', split_file))