In [2]:
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
import tensorflow as tf
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
import cv2
from collections import Counter
import glob


In [3]:

# Set random seeds for reproducibility
np.random.seed(42)
tf.random.set_seed(42)

class EEGCNNClassifier:
    def __init__(self, schizophrenia_path, healthy_path, img_size=(224, 224)):
        self.schizophrenia_path = schizophrenia_path
        self.healthy_path = healthy_path
        self.img_size = img_size
        self.X = []
        self.y = []
        self.fold_results = []
        
    def load_and_preprocess_data(self):
        """Load and preprocess HHT plot images"""
        print("Loading and preprocessing data...")
        
        # Load schizophrenia images (label: 1)
        schizo_files = glob.glob(os.path.join(self.schizophrenia_path, "*.png"))
        for file_path in schizo_files:
            try:
                img = load_img(file_path, target_size=self.img_size)
                img_array = img_to_array(img) / 255.0  # Normalize to [0,1]
                self.X.append(img_array)
                self.y.append(1)  # Schizophrenia
            except Exception as e:
                print(f"Error loading {file_path}: {e}")
        
        # Load healthy images (label: 0)
        healthy_files = glob.glob(os.path.join(self.healthy_path, "*.png"))
        for file_path in healthy_files:
            try:
                img = load_img(file_path, target_size=self.img_size)
                img_array = img_to_array(img) / 255.0  # Normalize to [0,1]
                self.X.append(img_array)
                self.y.append(0)  # Healthy
            except Exception as e:
                print(f"Error loading {file_path}: {e}")
        
        # Convert to numpy arrays
        self.X = np.array(self.X)
        self.y = np.array(self.y)
        
        print(f"Loaded {len(self.X)} images")
        print(f"Schizophrenia samples: {np.sum(self.y == 1)}")
        print(f"Healthy samples: {np.sum(self.y == 0)}")
        print(f"Image shape: {self.X.shape}")
        
    def create_cnn_model(self):
        """Create CNN architecture for EEG classification"""
        model = Sequential([
            # First Conv Block
            Conv2D(32, (3, 3), activation='relu', input_shape=(*self.img_size, 3)),
            BatchNormalization(),
            MaxPooling2D(2, 2),
            
            # Second Conv Block
            Conv2D(64, (3, 3), activation='relu'),
            BatchNormalization(),
            MaxPooling2D(2, 2),
            
            # Third Conv Block
            Conv2D(128, (3, 3), activation='relu'),
            BatchNormalization(),
            MaxPooling2D(2, 2),
            
            # Fourth Conv Block
            Conv2D(256, (3, 3), activation='relu'),
            BatchNormalization(),
            MaxPooling2D(2, 2),
            
            # Flatten and Dense layers
            Flatten(),
            Dense(512, activation='relu'),
            Dropout(0.5),
            Dense(256, activation='relu'),
            Dropout(0.3),
            Dense(1, activation='sigmoid')  # Binary classification
        ])
        
        model.compile(
            optimizer=Adam(learning_rate=0.001),
            loss='binary_crossentropy',
            metrics=['accuracy']
        )
        
        return model
    
    def create_data_augmentation(self):
        """Create data augmentation generator"""
        return ImageDataGenerator(
            rotation_range=10,
            width_shift_range=0.1,
            height_shift_range=0.1,
            horizontal_flip=True,
            zoom_range=0.1,
            fill_mode='nearest'
        )
    
    def train_with_cross_validation(self, n_splits=5, epochs=50):
        """Train model with 5-fold cross validation"""
        print(f"\nStarting {n_splits}-fold cross validation...")
        
        # Initialize StratifiedKFold
        skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
        
        fold_accuracies = []
        best_accuracy = 0
        best_model_path = "best_eeg_cnn_model.h5"
        
        # Create results directory
        os.makedirs("results", exist_ok=True)
        
        for fold, (train_idx, val_idx) in enumerate(skf.split(self.X, self.y)):
            print(f"\n{'='*50}")
            print(f"FOLD {fold + 1}/{n_splits}")
            print(f"{'='*50}")
            
            # Split data for current fold
            X_train, X_val = self.X[train_idx], self.X[val_idx]
            y_train, y_val = self.y[train_idx], self.y[val_idx]
            
            print(f"Training samples: {len(X_train)} (Schizo: {np.sum(y_train)}, Healthy: {len(y_train) - np.sum(y_train)})")
            print(f"Validation samples: {len(X_val)} (Schizo: {np.sum(y_val)}, Healthy: {len(y_val) - np.sum(y_val)})")
            
            # Create model for this fold
            model = self.create_cnn_model()
            
            # Callbacks
            checkpoint = ModelCheckpoint(
                f'results/fold_{fold+1}_model.h5',
                monitor='val_accuracy',
                save_best_only=True,
                mode='max',
                verbose=1
            )
            
            early_stopping = EarlyStopping(
                monitor='val_loss',
                patience=10,
                restore_best_weights=True,
                verbose=1
            )
            
            reduce_lr = ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.2,
                patience=5,
                min_lr=1e-7,
                verbose=1
            )
            
            callbacks = [checkpoint, early_stopping, reduce_lr]
            
            # Data augmentation for training
            datagen = self.create_data_augmentation()
            
            # Train model
            history = model.fit(
                datagen.flow(X_train, y_train, batch_size=32),
                steps_per_epoch=len(X_train) // 32,
                epochs=epochs,
                validation_data=(X_val, y_val),
                callbacks=callbacks,
                verbose=1
            )
            
            # Load best model for this fold
            best_fold_model = load_model(f'results/fold_{fold+1}_model.h5')
            
            # Evaluate on validation set
            val_predictions = best_fold_model.predict(X_val)
            val_predictions_binary = (val_predictions > 0.5).astype(int).flatten()
            
            # Calculate accuracy
            fold_accuracy = accuracy_score(y_val, val_predictions_binary)
            fold_accuracies.append(fold_accuracy)
            
            print(f"\nFold {fold + 1} Validation Accuracy: {fold_accuracy:.4f}")
            
            # Save best overall model
            if fold_accuracy > best_accuracy:
                best_accuracy = fold_accuracy
                best_fold_model.save(best_model_path)
                print(f"New best model saved with accuracy: {best_accuracy:.4f}")
            
            # Generate confusion matrix for this fold
            cm = confusion_matrix(y_val, val_predictions_binary)
            
            # Plot confusion matrix
            plt.figure(figsize=(8, 6))
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                       xticklabels=['Healthy', 'Schizophrenia'],
                       yticklabels=['Healthy', 'Schizophrenia'])
            plt.title(f'Confusion Matrix - Fold {fold + 1}\nAccuracy: {fold_accuracy:.4f}')
            plt.ylabel('True Label')
            plt.xlabel('Predicted Label')
            plt.tight_layout()
            plt.savefig(f'results/confusion_matrix_fold_{fold+1}.png', dpi=300, bbox_inches='tight')
            plt.show()
            
            # Classification report
            print(f"\nClassification Report - Fold {fold + 1}:")
            print(classification_report(y_val, val_predictions_binary, 
                                      target_names=['Healthy', 'Schizophrenia']))
            
            # Store fold results
            self.fold_results.append({
                'fold': fold + 1,
                'accuracy': fold_accuracy,
                'confusion_matrix': cm,
                'history': history.history
            })
        
        # Calculate overall statistics
        mean_accuracy = np.mean(fold_accuracies)
        std_accuracy = np.std(fold_accuracies)
        
        print(f"\n{'='*60}")
        print(f"CROSS VALIDATION RESULTS")
        print(f"{'='*60}")
        print(f"Individual fold accuracies: {[f'{acc:.4f}' for acc in fold_accuracies]}")
        print(f"Mean Accuracy: {mean_accuracy:.4f} ± {std_accuracy:.4f}")
        print(f"Best Accuracy: {best_accuracy:.4f}")
        print(f"Best model saved as: {best_model_path}")
        
        return mean_accuracy, std_accuracy, best_accuracy
    
    def plot_training_history(self):
        """Plot training history for all folds"""
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        
        for i, result in enumerate(self.fold_results):
            history = result['history']
            fold = result['fold']
            
            # Plot accuracy
            axes[0, 0].plot(history['accuracy'], label=f'Fold {fold} Train')
            axes[0, 1].plot(history['val_accuracy'], label=f'Fold {fold} Val')
            
            # Plot loss
            axes[1, 0].plot(history['loss'], label=f'Fold {fold} Train')
            axes[1, 1].plot(history['val_loss'], label=f'Fold {fold} Val')
        
        axes[0, 0].set_title('Training Accuracy')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Accuracy')
        axes[0, 0].legend()
        axes[0, 0].grid(True)
        
        axes[0, 1].set_title('Validation Accuracy')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Accuracy')
        axes[0, 1].legend()
        axes[0, 1].grid(True)
        
        axes[1, 0].set_title('Training Loss')
        axes[1, 0].set_xlabel('Epoch')
        axes[1, 0].set_ylabel('Loss')
        axes[1, 0].legend()
        axes[1, 0].grid(True)
        
        axes[1, 1].set_title('Validation Loss')
        axes[1, 1].set_xlabel('Epoch')
        axes[1, 1].set_ylabel('Loss')
        axes[1, 1].legend()
        axes[1, 1].grid(True)
        
        plt.tight_layout()
        plt.savefig('results/training_history.png', dpi=300, bbox_inches='tight')
        plt.show()

# Usage example
if __name__ == "__main__":
    # Update these paths to your actual data directories
    SCHIZOPHRENIA_PATH = "D:/HHT/S"  # Replace with your path
    HEALTHY_PATH = "D:/HHT/H"  # Replace with your path
    
    # Initialize classifier
    classifier = EEGCNNClassifier(
        schizophrenia_path=SCHIZOPHRENIA_PATH,
        healthy_path=HEALTHY_PATH,
        img_size=(224, 224)  # Adjust based on your image size
    )
    
    # Load and preprocess data
    classifier.load_and_preprocess_data()
    
    # Train with 5-fold cross validation
    mean_acc, std_acc, best_acc = classifier.train_with_cross_validation(
        n_splits=5, 
        epochs=50
    )
    
    # Plot training history
    classifier.plot_training_history()
    
    print(f"\nFinal Results:")
    print(f"Mean CV Accuracy: {mean_acc:.4f} ± {std_acc:.4f}")
    print(f"Best Model Accuracy: {best_acc:.4f}")
    print(f"Best model saved as: best_eeg_cnn_model.h5")

Loading and preprocessing data...
Loaded 9381 images
Schizophrenia samples: 5146
Healthy samples: 4235
Image shape: (9381, 224, 224, 3)

Starting 5-fold cross validation...

FOLD 1/5
Training samples: 7504 (Schizo: 4116, Healthy: 3388)
Validation samples: 1877 (Schizo: 1030, Healthy: 847)


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  self._warn_if_super_not_called()


Epoch 1/50
[1m100/234[0m [32m━━━━━━━━[0m[37m━━━━━━━━━━━━[0m [1m5:26[0m 2s/step - accuracy: 0.5083 - loss: 6.5250

: 