In [6]:
import os
import numpy as np
import matplotlib.pyplot as plt
import requests
import zipfile
import io
import mne
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten, Conv1D, MaxPooling1D
from tensorflow.keras.utils import to_categorical

# --- Pipeline configuration -------------------------------------------------
# Remote source: base URL of the PhysioNet "eegmmidb" 1.0.0 file tree and the
# subject whose archive is requested from it.
DOWNLOAD_URL = 'https://physionet.org/files/eegmmidb/1.0.0/'
SAMPLE_SUBJECT = 'S001'

# Local folder that receives the downloaded archive and the saved model.
DATA_DIR = './eeg_data'

# Data/model dimensions shared by the loader and the network definition.
NUM_CHANNELS = 64
NUM_CLASSES = 4  # Motor imagery tasks: left hand, right hand, both hands, feet

# Make sure the data folder is present (no error if it already is).
os.makedirs(DATA_DIR, exist_ok=True)
print(f"Data directory: {DATA_DIR}")

def download_sample_data():
    """Download and extract the sample subject's archive from PhysioNet.

    The archive URL is ``DOWNLOAD_URL + SAMPLE_SUBJECT + ".zip"`` and the
    archive is stored as ``DATA_DIR/<SAMPLE_SUBJECT>.zip``. If that file
    already exists the function returns immediately.

    The download is streamed to a temporary ``.part`` file and only renamed
    to its final name after extraction has succeeded, so an interrupted
    download or a corrupt archive never leaves behind a file that would make
    later runs skip the download.

    Raises:
        Exception: Any error raised while downloading, writing or extracting
            is printed and then re-raised unchanged.
    """
    subject_url = f"{DOWNLOAD_URL}{SAMPLE_SUBJECT}.zip"
    zip_path = os.path.join(DATA_DIR, f"{SAMPLE_SUBJECT}.zip")

    # Check if data already exists
    if os.path.exists(zip_path):
        print("Sample data already downloaded, skipping download")
        return

    print(f"Downloading sample data from {subject_url}...")

    # Work on a side file so zip_path only ever names a complete archive.
    partial_path = f"{zip_path}.part"

    try:
        # (connect, read) timeouts: without them a stalled server blocks the
        # call forever. stream=True avoids holding the whole archive in RAM.
        with requests.get(subject_url, stream=True, timeout=(10, 60)) as response:
            response.raise_for_status()

            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)

        print("Download complete, extracting files...")

        # Extract the zip file
        with zipfile.ZipFile(partial_path, 'r') as zip_ref:
            zip_ref.extractall(DATA_DIR)

        # Publish the archive under its final name only after a clean extract.
        os.replace(partial_path, zip_path)

        print("Extraction complete")
    except Exception as e:
        print(f"Error downloading or extracting data: {e}")
        # Best-effort cleanup; a failure here must not mask the real error.
        try:
            os.remove(partial_path)
        except OSError:
            pass
        raise

def load_edf_files():
    """
    Produce the EEG trials and labels used by the rest of the pipeline.

    This is a stand-in: no EDF file is read. A real implementation would
    parse the subject's recordings, e.g. with
    ``mne.io.read_raw_edf(file_path, preload=True)``. Here random data of
    the same layout is generated instead.

    Returns:
        Tuple ``(X, y)`` where ``X`` has shape
        ``(100, NUM_CHANNELS, 200)`` (trials, channels, time points) drawn
        from a standard normal distribution, and ``y`` holds 100 integer
        labels in ``[0, NUM_CLASSES)``
        (0: left hand, 1: right hand, 2: both hands, 3: feet).
    """
    print("Loading EEG data...")

    # Size of the simulated recording: 100 trials, 200 time points each.
    trial_count, samples_per_trial = 100, 200

    # Signals first, labels second (draw order from np.random is kept).
    signals = np.random.randn(trial_count, NUM_CHANNELS, samples_per_trial)
    labels = np.random.randint(0, NUM_CLASSES, size=trial_count)

    for message in (
        "Data loading complete",
        f"X shape: {signals.shape}",
        f"y shape: {labels.shape}",
    ):
        print(message)

    return signals, labels

def preprocess_data(X, y):
    """Turn per-trial EEG recordings into one sample per time point.

    Each time point of each trial becomes its own sample consisting of the
    values of all channels at that instant, and inherits the label of the
    trial it came from.

    Args:
        X: Array of shape ``(trials, channels, time_points)``.
        y: One label per trial (length ``trials``).

    Returns:
        Tuple ``(X_reshaped, y_reshaped)`` where ``X_reshaped`` has shape
        ``(trials * time_points, channels, 1)`` and ``y_reshaped`` has shape
        ``(trials * time_points,)``. Samples are ordered trial by trial, and
        by time within a trial.

    Raises:
        ValueError: If ``X`` is not 3-dimensional or ``y`` does not contain
            exactly one label per trial.
    """
    print("Preprocessing EEG data...")

    X = np.asarray(X)
    num_trials, num_channels, time_points = X.shape

    if len(y) != num_trials:
        raise ValueError(
            f"Expected {num_trials} labels (one per trial), got {len(y)}"
        )

    # The time axis must come before the channel axis *before* flattening.
    # Reshaping the (trials, channels, time) array directly would slice the
    # buffer into runs of consecutive time values from a single channel
    # instead of one value per channel at a single time point.
    X_time_major = np.transpose(X, (0, 2, 1))
    X_reshaped = X_time_major.reshape(num_trials * time_points, num_channels, 1)

    # Repeat labels for each time point (same trial-major order as above)
    y_reshaped = np.repeat(y, time_points)

    print("Data preprocessing complete")
    print(f"X_reshaped shape: {X_reshaped.shape}")
    print(f"y_reshaped shape: {y_reshaped.shape}")

    return X_reshaped, y_reshaped

def split_data(X, y):
    """Create an 80/20 stratified train/test split and one-hot encode labels.

    Args:
        X: Samples to split.
        y: Integer class labels aligned with ``X``.

    Returns:
        ``(X_train, y_train_onehot, X_test, y_test_onehot, y_test)`` — the
        two feature sets, their one-hot encoded labels (``NUM_CLASSES``
        columns), and the test labels in their original integer form.
    """
    print("Splitting data into training and testing sets...")

    # Fixed seed and stratification on the labels
    partitions = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )
    X_train, X_test, y_train, y_test = partitions

    # Integer labels -> one-hot rows, for both partitions
    y_train_onehot, y_test_onehot = (
        to_categorical(labels, NUM_CLASSES) for labels in (y_train, y_test)
    )

    print("Data split complete")
    for subset_name, subset in (("Training", X_train), ("Testing", X_test)):
        print(f"{subset_name} set: {subset.shape[0]} samples")

    return X_train, y_train_onehot, X_test, y_test_onehot, y_test

def build_cnn_model():
    """Assemble and compile the 1-D CNN classifier.

    Architecture: two Conv1D + MaxPooling1D stages (32 then 64 filters),
    followed by Flatten, Dense(128), Dropout(0.5) and a softmax output with
    ``NUM_CLASSES`` units. The input shape is ``(NUM_CHANNELS, 1)``.

    Returns:
        The compiled model (Adam optimizer, categorical cross-entropy loss,
        accuracy metric). Its summary is printed before returning.
    """
    print("Building CNN model...")

    # Settings common to both convolutional layers
    conv_settings = dict(kernel_size=3, activation='relu', padding='same')

    layer_stack = [
        # Stage 1
        Conv1D(filters=32, input_shape=(NUM_CHANNELS, 1), **conv_settings),
        MaxPooling1D(pool_size=2, strides=2),
        # Stage 2
        Conv1D(filters=64, **conv_settings),
        MaxPooling1D(pool_size=2, strides=2),
        # Classifier head
        Flatten(),
        Dense(128, activation='relu'),
        Dropout(0.5),
        Dense(NUM_CLASSES, activation='softmax'),
    ]
    model = Sequential(layer_stack)

    model.compile(
        loss='categorical_crossentropy',
        optimizer='adam',
        metrics=['accuracy'],
    )
    model.summary()

    return model

def train_model(model, X_train, y_train, X_test, y_test):
    """Fit the model and plot its accuracy/loss curves.

    Args:
        model: Compiled model to train.
        X_train: Training samples.
        y_train: Training targets.
        X_test: Samples passed to ``fit`` as validation data.
        y_test: Targets passed to ``fit`` as validation data.

    Returns:
        ``(model, history)`` — the same model object after training and the
        history object returned by ``model.fit``.
    """
    print("Training the model...")

    # 10 epochs, mini-batches of 32; the test split doubles as validation set
    history = model.fit(
        X_train,
        y_train,
        validation_data=(X_test, y_test),
        epochs=10,
        batch_size=32,
        verbose=1,
    )

    print("Model training complete")

    # One panel per tracked quantity: (history key, display name)
    curves = (('accuracy', 'Accuracy'), ('loss', 'Loss'))
    recorded = history.history

    plt.figure(figsize=(12, 4))
    for panel, (key, caption) in enumerate(curves, start=1):
        plt.subplot(1, 2, panel)
        plt.plot(recorded[key], label=f'Training {caption}')
        plt.plot(recorded[f'val_{key}'], label=f'Validation {caption}')
        plt.title(f'Model {caption}')
        plt.xlabel('Epoch')
        plt.ylabel(caption)
        plt.legend()

    plt.tight_layout()
    plt.show()

    return model, history

def evaluate_model(model, X_test, y_test_onehot, y_test):
    """Report test accuracy, a confusion matrix plot and per-class metrics.

    Args:
        model: Trained model exposing ``evaluate`` and ``predict``.
        X_test: Test samples.
        y_test_onehot: One-hot encoded test labels (used for ``evaluate``).
        y_test: Integer test labels (used for the confusion matrix and the
            classification report).
    """
    print("Evaluating the model...")

    # Evaluate the model (the loss value is not reported)
    _, accuracy = model.evaluate(X_test, y_test_onehot, verbose=0)
    print(f"Test accuracy: {accuracy:.4f}")

    # Make predictions: most probable class per sample
    y_pred_prob = model.predict(X_test)
    y_pred = np.argmax(y_pred_prob, axis=1)

    classes = ['Left Hand', 'Right Hand', 'Both Hands', 'Feet']
    class_ids = np.arange(len(classes))

    # Pin the label set explicitly. Without it, a class missing from both
    # y_test and y_pred shrinks the matrix below len(classes) and makes
    # classification_report reject target_names.
    cm = confusion_matrix(y_test, y_pred, labels=class_ids)

    # Plot confusion matrix
    plt.figure(figsize=(8, 6))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    plt.colorbar()

    plt.xticks(class_ids, classes, rotation=45)
    plt.yticks(class_ids, classes)

    # Add text annotations; switch to white text on the darker cells
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, format(cm[i, j], 'd'),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    # Lay out after the axis labels exist so they are included in the fit
    plt.tight_layout()
    plt.show()

    # Print classification report
    print("\nClassification Report:")
    print(classification_report(
        y_test, y_pred, labels=class_ids, target_names=classes
    ))

def main():
    """Run the full pipeline: fetch, load, preprocess, split, train, evaluate, save."""
    # 1. Fetch the sample archive (skipped when it is already on disk)
    download_sample_data()

    # 2-3. Load the trials and flatten them into per-time-point samples
    raw_X, raw_y = load_edf_files()
    flat_X, flat_y = preprocess_data(raw_X, raw_y)

    # 4. Train/test partition with one-hot targets
    X_train, y_train, X_test, y_test_onehot, y_test = split_data(flat_X, flat_y)

    # 5-6. Build the network, then fit it (test split used for validation)
    network = build_cnn_model()
    network, _history = train_model(
        network, X_train, y_train, X_test, y_test_onehot
    )

    # 7. Metrics, confusion matrix and classification report
    evaluate_model(network, X_test, y_test_onehot, y_test)

    # Persist the trained network next to the data
    model_save_path = os.path.join(DATA_DIR, 'eeg_cnn_model.h5')
    network.save(model_save_path)
    print(f"Model saved to {model_save_path}")

# Run the pipeline only when this file is executed directly; importing it as a
# module defines the functions without starting a download or training run.
if __name__ == "__main__":
    main()

ModuleNotFoundError: No module named 'mne'