In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# import torchvision
# import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import pandas as pd
from Efficient_Wav_FastKAN import *
import scipy.io as sio
import numpy as np
from sklearn.preprocessing import StandardScaler,MinMaxScaler  # Used for standardized processing
from sklearn.metrics import confusion_matrix # Calculate the sensitivity and specificity
import openpyxl
import time

In [9]:
## 
from contrast_mlpLayers import *
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch import nn, optim
from sklearn.model_selection import KFold
from scipy.io import loadmat
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix

# Reproducibility: seed NumPy's legacy global RNG and PyTorch's default generator
# with the same fixed value so repeated runs draw identical random numbers.
np.random.seed(42)
torch.manual_seed(42)

# Compute device shared by the dataset/model code below: CUDA when PyTorch reports
# it as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 
class PatientDataset(Dataset):
    """Map-style dataset holding all samples and labels as tensors on ``device``.

    The arrays are converted once in the constructor (float32 samples, int64
    labels) and moved to the module-level ``device``; indexing then only slices
    the stored tensors and optionally applies ``transform`` to the sample.
    """

    def __init__(self, data, labels, transform=None):
        # Convert first, then transfer: each array is copied to the device once.
        samples = torch.tensor(data, dtype=torch.float32)
        targets = torch.tensor(labels, dtype=torch.long)
        self.data = samples.to(device)
        self.labels = targets.to(device)
        # Optional callable applied to a single sample in __getitem__.
        self.transform = transform

    def __len__(self):
        # Number of samples = length of the stored sample tensor.
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        if self.transform:
            item = self.transform(item)
        # The label is returned untouched; only the sample goes through transform.
        return item, self.labels[idx]

def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, epochs=30):
    """Train ``model`` for ``epochs`` epochs, then evaluate it once on ``val_loader``.

    Args:
        model: Network called as ``model(inputs)``; dim 1 of its output indexes
            the classes (the prediction is the argmax over that dim).
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Learning-rate scheduler stepped once per epoch.
        epochs: Number of passes over ``train_loader``.

    Returns:
        Tuple ``(accuracy, sensitivity, specificity, mean_inference_ms)``.
        The first three are percentages computed with label 1 as the positive
        class and label 0 as the negative class. Sensitivity / specificity are
        NaN when the validation data holds no positive / negative samples.
        ``mean_inference_ms`` is the mean forward-pass time per validation
        batch in milliseconds.

    Raises:
        ValueError: If an epoch sees no training samples or ``val_loader``
            yields no samples.

    Side effects:
        Prints training progress and pickles the whole trained model to
        ``'KAN_variant.pth'`` in the current working directory (a later
        notebook cell reloads that file).
    """
    for epoch in range(epochs):
        # ---- Training ----
        model.train()
        running_loss = 0.0
        train_correct = 0
        train_total = 0

        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            n_batch = labels.size(0)
            predicted = outputs.detach().argmax(dim=1)
            train_total += n_batch
            train_correct += (predicted == labels).sum().item()
            # Weight the batch loss by batch size (assumes the criterion returns
            # a per-batch mean) so the epoch figure is a per-sample average.
            running_loss += loss.item() * n_batch

        if train_total == 0:
            raise ValueError(
                "train_loader produced no samples (is the training split smaller "
                "than batch_size while drop_last=True?)"
            )

        train_acc = 100 * train_correct / train_total
        # Divide by the samples actually seen: with drop_last=True the loader
        # skips the last partial batch, so len(dataset) would understate the loss.
        train_loss = running_loss / train_total

        # Learning rate scheduling
        scheduler.step()

        if epoch % 10 == 0 or epoch == epochs-1:
            print(f"Epoch {epoch+1}/{epochs} | "
                  f"Train Loss: {train_loss:.4f} | "
                  f"Train Acc: {train_acc:.2f}% | "

                 )

    # Persist the trained model for the later parameter-count cell. The model is
    # not reloaded here: the in-memory instance is the one being evaluated.
    torch.save(model, 'KAN_variant.pth')

    # ---- Validation ----
    model.eval()
    val_correct = 0
    val_total = 0
    all_predicted = []
    all_labels = []
    inference_time = []

    with torch.no_grad():
        for inputs, labels in val_loader:
            # CUDA kernels run asynchronously; synchronise around the forward
            # pass so the measured interval covers the real computation.
            if inputs.is_cuda:
                torch.cuda.synchronize()
            start = time.perf_counter()
            outputs = model(inputs)
            if inputs.is_cuda:
                torch.cuda.synchronize()
            end = time.perf_counter()

            predicted = outputs.argmax(dim=1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()

            all_predicted.extend(predicted.cpu().numpy().ravel().tolist())
            all_labels.extend(labels.cpu().numpy().ravel().tolist())

            inference_time.append((end - start) * 1000)

    if val_total == 0:
        raise ValueError("val_loader produced no samples")

    val_acc = 100 * val_correct / val_total

    # Count the binary confusion-matrix cells directly so the result is always
    # well-formed, even when the fold (or the predictions) contain one class only.
    y_true = np.asarray(all_labels)
    y_pred = np.asarray(all_predicted)
    tp = int(np.sum((y_true == 1) & (y_pred == 1)))
    fn = int(np.sum((y_true == 1) & (y_pred == 0)))
    tn = int(np.sum((y_true == 0) & (y_pred == 0)))
    fp = int(np.sum((y_true == 0) & (y_pred == 1)))

    # A metric whose denominator is zero is undefined; report NaN instead of crashing.
    val_sen = tp / (tp + fn) if (tp + fn) else float('nan')
    val_spc = tn / (tn + fp) if (tn + fp) else float('nan')

    return val_acc, val_sen*100, val_spc*100, np.mean(inference_time)
 
def process_patient(patient_dir, pat_id, epochs=30, batch_size=64):
    """Run 5-fold cross-validation for one patient and report the averaged metrics.

    Loads ``data_combined_shuffled_{pat_id}.mat`` from ``patient_dir``, reads its
    ``'data_combined_shuffled'`` matrix, and uses the last column as the class
    label and all other columns as features. For every fold the features are
    standardised with statistics fitted on the training split only, a fresh KAN
    model is trained via ``train_model`` and the fold metrics are collected.

    Args:
        patient_dir: Directory containing the patient's ``.mat`` file.
        pat_id: Patient identifier used in the file name and in the printout.
        epochs: Training epochs per fold.
        batch_size: Mini-batch size for both loaders.

    Returns:
        Tuple ``(mean_accuracy, mean_sensitivity, mean_specificity,
        mean_inference_ms)`` over the 5 folds, or ``None`` when anything in the
        pipeline raises (the error is printed, not re-raised).
    """
    try:
        # Load data
        mat_path = os.path.join(patient_dir, f'data_combined_shuffled_{pat_id}.mat')
        combined = loadmat(mat_path)['data_combined_shuffled'].astype(np.float32)
        X = combined[:, :-1]
        y = combined[:, -1].astype(np.int64)

        # 5-fold cross validation
        kf = KFold(n_splits=5, shuffle=True, random_state=42)
        fold_acc = []
        fold_sen = []
        fold_spc = []
        fold_infer_time = []

        for fold, (train_idx, val_idx) in enumerate(kf.split(X)):
            print(f"\nFold {fold+1}")

            # Split data
            X_train, X_val = X[train_idx], X[val_idx]
            y_train, y_val = y[train_idx], y[val_idx]

            # Standardization: fit on the training split only so no validation
            # statistics leak into training.
            scaler = StandardScaler()
            X_train_scaled = scaler.fit_transform(X_train)
            X_val_scaled = scaler.transform(X_val)

            # Create a data set (PatientDataset places its tensors on `device`)
            train_dataset = PatientDataset(X_train_scaled, y_train)
            val_dataset = PatientDataset(X_val_scaled, y_val)

            # Create loader
            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
            val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

            # Initialize the model
            # Wav-KAN(Wav) or FastKAN(FastKAN) or Efficient-KAN(Spline)
            input_dim = X_train_scaled.shape[1]
            # The batches live on `device`, so the model must be moved there too
            # (before the optimizer is built) or a CUDA run fails with a device
            # mismatch. On CPU this is a no-op.
            model = KAN([input_dim, 512, 2], wavelet_type='dog', kan_type='Wav').to(device)
            criterion = nn.CrossEntropyLoss()
            optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
            scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.96)

            # Training model
            best_acc, best_sen, best_spc, inference_time = train_model(model, train_loader, val_loader,
                                 criterion, optimizer, scheduler, epochs)

            fold_acc.append(best_acc)
            fold_sen.append(best_sen)
            fold_spc.append(best_spc)
            fold_infer_time.append(inference_time)

        # Output the patient results
        print(f"\n{'='*30}\n Chb: {pat_id}")
        print(f"Accuracy of the 5-fold cross validation: {[f'{acc:.2f}' for acc in fold_acc]}")
        print(f"Sensitivity of the 5-fold cross validation: {[f'{sen:.2f}' for sen in fold_sen]}")
        print(f"Specificity of the 5-fold cross validation: {[f'{spc:.2f}' for spc in fold_spc]}")

        print(f"Average_accuracy: {np.mean(fold_acc):.2f} ± {np.std(fold_acc):.3f}")
        print(f"Average_sensitivity: {np.mean(fold_sen):.2f} ± {np.std(fold_sen):.3f}")
        print(f"Average_specificity: {np.mean(fold_spc):.2f} ± {np.std(fold_spc):.3f}")
        print(f"Average_inference_time：{np.mean(fold_infer_time):.2f}ms")

        return np.mean(fold_acc), np.mean(fold_sen), np.mean(fold_spc), np.mean(fold_infer_time)

    except Exception as e:
        # Best-effort by design: a patient whose file is missing or whose run
        # fails is reported and skipped; callers must handle the None result.
        print(f"\nHanding {patient_dir} has a error: {str(e)}")
        return None

def main(root_dir, epochs=30, batch_size=64, patient_ids=None):
    """Evaluate every patient and print the metrics averaged over patients.

    Args:
        root_dir: Directory passed to ``process_patient`` as the data location.
        epochs: Training epochs per cross-validation fold.
        batch_size: Mini-batch size.
        patient_ids: Iterable of patient ids to process; defaults to 1..29,
            the range this notebook originally hard-coded.

    Patients for which ``process_patient`` returns ``None`` (it does so on any
    error) are skipped instead of aborting the whole run.
    """
    if patient_ids is None:
        patient_ids = range(1, 30)

    all_results = []
    all_results_sen = []
    all_results_spc = []
    all_results_infer_time = []

    for pat_id in patient_ids:
        patient_result = process_patient(root_dir, pat_id, epochs, batch_size)
        # Check for failure BEFORE unpacking: unpacking None raises TypeError.
        if patient_result is None:
            continue
        result, result_sen, result_spc, result_infer_time = patient_result
        all_results.append(result)
        all_results_sen.append(result_sen)
        all_results_spc.append(result_spc)
        all_results_infer_time.append(result_infer_time)

    print(f"\n{'='*40}")
    if not all_results:
        # Averaging an empty list would only print NaN plus runtime warnings.
        print("No patient was processed successfully; no overall statistics to report.")
        return

    print(f"Overall average accuracy: {np.mean(all_results):.4f} ± {np.std(all_results):.4f}")
    print(f"Overall average sensitivity: {np.mean(all_results_sen):.4f} ± {np.std(all_results_sen):.4f}")
    print(f"Overall average specificity: {np.mean(all_results_spc):.4f} ± {np.std(all_results_spc):.4f}")
    print(f"Overall average inference time: {np.mean(all_results_infer_time):.4f}ms")

if __name__ == "__main__":
    # Run configuration: data location and training hyper-parameters.
    ROOT_DIR = "E:/BaiduSyncdisk/EEG/Huashan_data"
    EPOCHS, BATCH_SIZE = 30, 64

    # Show which compute device was selected, then start the full evaluation.
    print(f"Device: {device}")
    main(root_dir=ROOT_DIR, epochs=EPOCHS, batch_size=BATCH_SIZE)

Device: cpu

Fold 1
Epoch 1/30 | Train Loss: 0.4833 | Train Acc: 76.30% | 
Epoch 11/30 | Train Loss: 0.1394 | Train Acc: 99.48% | 
Epoch 21/30 | Train Loss: 0.1241 | Train Acc: 100.00% | 
Epoch 30/30 | Train Loss: 0.1079 | Train Acc: 100.00% | 

Fold 2
Epoch 1/30 | Train Loss: 0.3162 | Train Acc: 84.90% | 
Epoch 11/30 | Train Loss: 0.1344 | Train Acc: 99.22% | 
Epoch 21/30 | Train Loss: 0.1065 | Train Acc: 99.74% | 
Epoch 30/30 | Train Loss: 0.1072 | Train Acc: 99.74% | 

Fold 3
Epoch 1/30 | Train Loss: 0.4139 | Train Acc: 81.51% | 
Epoch 11/30 | Train Loss: 0.1359 | Train Acc: 98.70% | 
Epoch 21/30 | Train Loss: 0.1197 | Train Acc: 99.74% | 
Epoch 30/30 | Train Loss: 0.1088 | Train Acc: 99.48% | 

Fold 4
Epoch 1/30 | Train Loss: 0.2858 | Train Acc: 87.50% | 
Epoch 11/30 | Train Loss: 0.1292 | Train Acc: 99.48% | 
Epoch 21/30 | Train Loss: 0.1103 | Train Acc: 100.00% | 
Epoch 30/30 | Train Loss: 0.1003 | Train Acc: 100.00% | 

Fold 5
Epoch 1/30 | Train Loss: 0.3521 | Train Acc: 83.07% 

KeyboardInterrupt: 

In [6]:
# Report the size of the saved model: total and trainable parameter counts,
# both expressed in millions.
from thop import clever_format, profile

# NOTE(review): torch.load unpickles a whole model object here; only load
# checkpoint files produced by this notebook or another trusted source.
model_complexity = torch.load('KAN_variant.pth')

# Total = element count summed over every parameter tensor.
total_params = sum(map(torch.numel, model_complexity.parameters()))
# Trainable = the same sum restricted to tensors that require gradients.
trainable_params = sum(
    map(torch.numel, filter(lambda param: param.requires_grad, model_complexity.parameters()))
)

print(f"Total Parameters: {total_params/1e6:.2f}M")
print(f"Trainable Parameters: {trainable_params/1e6:.2f}M")


Total Parameters: 0.56M
Trainable Parameters: 0.56M
