In [4]:
import os
import random
import numpy as np
import scipy.io
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.metrics import confusion_matrix, roc_curve, auc, classification_report
from sklearn.utils import shuffle
from sklearn.neural_network import MLPClassifier
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns
from read_eeg import get_eeg_info

In [6]:
# --- Parameters ---
# Folder holding the .mat recordings. Set the EEG_DATA_DIR environment
# variable to point elsewhere without editing the notebook; the fallback is
# the path this notebook was originally written against.
data_folder = os.environ.get('EEG_DATA_DIR', 'C:/Users/Haiya/Downloads/Data/Data/')
fs = 500  # sampling rate; also passed to welch() during feature extraction
min_samples = int(0.5 * 60 * fs)  # shortest segment length, in samples
max_samples = int(1.0 * 60 * fs)  # longest segment length, in samples
num_segments = 50  # number of random segments drawn from each recording file

# --- File Loading ---
if not os.path.isdir(data_folder):
    raise FileNotFoundError(
        f"EEG data folder not found: {data_folder!r}. "
        "Set the EEG_DATA_DIR environment variable or edit data_folder."
    )

# os.listdir() returns entries in arbitrary order, so sort before shuffling;
# otherwise random.seed(1) does not reproduce the same split on every machine.
files = sorted(f for f in os.listdir(data_folder) if f.endswith('.mat'))
# Files whose name contains '_ad' form the AD class; everything else is healthy.
ad_files = [f for f in files if '_ad' in f]
healthy_files = [f for f in files if '_ad' not in f]

random.seed(1)
random.shuffle(ad_files)
random.shuffle(healthy_files)

FileNotFoundError: [WinError 3] The system cannot find the path specified: 'C:/Users/Haiya/Downloads/Data/Data/'

In [None]:
def split_files(file_list):
    """Cut ``file_list`` into consecutive train / validation / test slices.

    With ``n = len(file_list)``, the first ``int(0.8 * n)`` entries form the
    train slice, entries up to ``int(0.9 * n)`` form the validation slice and
    the rest form the test slice. Order is preserved; nothing is shuffled here.

    Returns:
        A ``(train, val, test)`` tuple of slices of ``file_list``.
    """
    total = len(file_list)
    train_end = int(0.8 * total)
    val_end = int(0.9 * total)

    train_part = file_list[:train_end]
    val_part = file_list[train_end:val_end]
    test_part = file_list[val_end:]
    return train_part, val_part, test_part

# Split each class on its own, then pool the two classes per split
# (AD files first, healthy files after).
(ad_train, ad_val, ad_test), (hl_train, hl_val, hl_test) = (
    split_files(ad_files),
    split_files(healthy_files),
)

train_files = [*ad_train, *hl_train]
val_files = [*ad_val, *hl_val]
test_files = [*ad_test, *hl_test]

# --- Helper Functions ---
def segment_subjects(file_list, folder, num_segments, min_len, max_len):
    """Cut randomly placed, random-length segments out of each recording.

    For every name in ``file_list`` the .mat file in ``folder`` is loaded and
    the first trial of its ``transferred_EEG`` struct is used. The trial is
    indexed as a 2-D array with samples along axis 1. ``num_segments`` windows
    are drawn from it with Python's global ``random`` generator, so results
    are reproducible only if ``random.seed`` was set beforehand.

    Args:
        file_list: .mat file names (relative to ``folder``).
        folder: directory containing the files.
        num_segments: number of segments drawn per file.
        min_len: smallest segment length, in samples.
        max_len: largest segment length, in samples. It is capped at the
            recording length so a short recording cannot make the start-index
            draw fail.

    Returns:
        ``(segments, labels)``: a list of 2-D array slices and a parallel list
        of labels, 1 when the file name contains ``'_ad'`` and 0 otherwise.

    Raises:
        ValueError: if a recording holds fewer than ``min_len`` samples.
    """
    segments = []
    labels = []
    for file in file_list:
        mat = scipy.io.loadmat(os.path.join(folder, file))
        eeg = mat['transferred_EEG']['trial'][0,0][0,0]  # first trial
        total_samples = eeg.shape[1]
        if total_samples < min_len:
            raise ValueError(
                f"{file}: recording has {total_samples} samples, "
                f"fewer than the minimum segment length {min_len}"
            )
        # Never ask for a window longer than the recording itself; without the
        # cap, random.randint(0, total_samples - length) gets an empty range.
        longest = min(max_len, total_samples)
        label = 1 if '_ad' in file else 0  # same for every segment of a file
        for _ in range(num_segments):
            length = random.randint(min_len, longest)
            start = random.randint(0, total_samples - length)
            segments.append(eeg[:, start:start+length])
            labels.append(label)
    return segments, labels

# Frequency bands as (low, high) edges; both edges are inclusive when the PSD is masked.
bands = {'Delta': (1, 4), 'Theta': (4, 8), 'Alpha': (8, 13), 'Beta': (13, 30), 'Gamma': (30, 70)}

def extract_features(eeg_data, fs):
    """Compute spectral features for every row of every segment.

    Each row of a segment is treated as one channel. Its power spectral
    density is estimated with Welch's method (Hamming window, 512-sample
    segments, 256-sample overlap, 512-point FFT). Per channel the features are,
    in order:

    * one relative power per entry of ``bands`` (band integral divided by the
      integral over 0.5-70 in the frequency units of ``fs``),
    * spectral entropy (base 2) of the PSD normalised to sum to 1,
    * PSD-weighted mean frequency,
    * theta/alpha and delta/alpha relative-power ratios.

    Args:
        eeg_data: iterable of 2-D arrays, one per segment.
        fs: sampling rate passed to ``scipy.signal.welch``.

    Returns:
        2-D array with one row per segment; each row is the per-channel
        feature lists concatenated, ``len(bands) + 4`` values per channel.

    Note:
        A channel whose PSD is all zeros makes both normalisers zero, so the
        divisions below produce NaN for that channel.
    """
    from scipy.signal import welch

    # np.trapz was renamed to np.trapezoid in NumPy 2.0 (the old name is
    # deprecated); use whichever the installed NumPy provides.
    integrate = getattr(np, 'trapezoid', None) or np.trapz
    eps = np.finfo(float).eps

    features = []
    for seg in eeg_data:
        ch_features = []
        for ch in seg:
            f, pxx = welch(ch, fs=fs, window='hamming', nperseg=512, noverlap=256, nfft=512)
            broadband = (f >= 0.5) & (f <= 70)
            total_power = integrate(pxx[broadband], f[broadband])
            psd_sum = np.sum(pxx)
            p_norm = pxx / psd_sum

            band_feats = []
            for low, high in bands.values():
                in_band = (f >= low) & (f <= high)
                band_feats.append(integrate(pxx[in_band], f[in_band]) / total_power)

            entropy = -np.sum(p_norm * np.log2(p_norm + eps))
            mean_freq = np.sum(f * pxx) / psd_sum
            # Positional indexing relies on the insertion order of `bands`:
            # 0 = Delta, 1 = Theta, 2 = Alpha.
            theta_alpha = band_feats[1] / (band_feats[2] + eps)
            delta_alpha = band_feats[0] / (band_feats[2] + eps)

            ch_features.append(band_feats + [entropy, mean_freq, theta_alpha, delta_alpha])

        features.append(np.array(ch_features).flatten())
    return np.array(features)

In [None]:
# --- Process Data ---
# All three splits are segmented with the same folder, segment count and
# length bounds; only the file list differs.
_segment_args = (data_folder, num_segments, min_samples, max_samples)
train_data, train_labels = segment_subjects(train_files, *_segment_args)
val_data, val_labels = segment_subjects(val_files, *_segment_args)
test_data, test_labels = segment_subjects(test_files, *_segment_args)

# Balance train set
# Shuffle first so the truncation below keeps an arbitrary subset of the
# larger class rather than its first segments in file order.
train_data, train_labels = shuffle(train_data, train_labels, random_state=1)
_label_array = np.asarray(train_labels)
ad_idx = np.flatnonzero(_label_array == 1).tolist()
hl_idx = np.flatnonzero(_label_array == 0).tolist()
min_class = min(len(ad_idx), len(hl_idx))
# Keep the same number of segments from each class (AD first, then healthy).
sel_idx = [*ad_idx[:min_class], *hl_idx[:min_class]]
train_data = [train_data[i] for i in sel_idx]
train_labels = [train_labels[i] for i in sel_idx]

# Feature extraction
train_features, val_features, test_features = (
    extract_features(split, fs) for split in (train_data, val_data, test_data)
)

# Standardise every split with statistics fitted on the training split only.
scaler = StandardScaler()
scaler.fit(train_features)
train_features, val_features, test_features = (
    scaler.transform(feats) for feats in (train_features, val_features, test_features)
)

# --- Train Model ---
# Hyper-parameters collected in one place; random_state fixes the weight
# initialisation and the shuffling done by the classifier.
_mlp_settings = {
    'hidden_layer_sizes': (64, 32),
    'alpha': 0.015,
    'learning_rate_init': 2e-4,
    'max_iter': 500,
    'batch_size': 32,
    'early_stopping': True,
    'random_state': 1,
}
model = MLPClassifier(**_mlp_settings)
model.fit(train_features, train_labels)

# --- Evaluation ---
# The validation split is segmented, featurised and scaled earlier in the
# notebook; report its accuracy too so it is not silently unused.
val_accuracy = np.mean(model.predict(val_features) == np.asarray(val_labels)) * 100
print(f"Validation Accuracy: {val_accuracy:.2f}%")

predicted_labels = model.predict(test_features)
accuracy = np.mean(predicted_labels == np.asarray(test_labels)) * 100
print(f"Test Accuracy: {accuracy:.2f}%")

# Pin the label order so C is always 2x2 (index 0 = Healthy, 1 = AD) and lines
# up with the tick labels even when a class is missing from the test data.
C = confusion_matrix(test_labels, predicted_labels, labels=[0, 1])
plt.figure()  # explicit figure: do not depend on plt.show() having reset state
sns.heatmap(C, annot=True, fmt='d', xticklabels=['Healthy', 'AD'], yticklabels=['Healthy', 'AD'])
plt.xlabel('Predicted'); plt.ylabel('True')
plt.title("Confusion Matrix (Raw)")
plt.show()

# Row-normalise to percentages; a row with no samples stays at 0 instead of
# becoming NaN through a 0/0 division.
row_totals = C.sum(axis=1, keepdims=True)
C_percent = 100 * C / np.where(row_totals == 0, 1, row_totals)
plt.figure()
sns.heatmap(C_percent, annot=True, fmt='.1f', cmap='viridis', xticklabels=['Healthy', 'AD'], yticklabels=['Healthy', 'AD'])
plt.xlabel('Predicted'); plt.ylabel('True')
plt.title("Confusion Matrix (%)")
plt.show()

# --- ROC Curve ---
# Column 1 of predict_proba is the probability of label 1 (the AD class).
scores = model.predict_proba(test_features)[:, 1]
fpr, tpr, _ = roc_curve(test_labels, scores)
roc_auc = auc(fpr, tpr)

_roc_ax = plt.gca()
_roc_ax.plot(fpr, tpr, label=f'AUC = {roc_auc:.2f}')
_roc_ax.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal
_roc_ax.set_xlabel('False Positive Rate')
_roc_ax.set_ylabel('True Positive Rate')
_roc_ax.set_title('ROC Curve')
_roc_ax.legend()
_roc_ax.grid(True)
plt.show()

# --- Additional Metrics ---
_report_text = classification_report(
    test_labels,
    predicted_labels,
    target_names=['Healthy', 'AD'],
)
print(_report_text)

# --- PCA and t-SNE Visualization ---
def _show_projection(points, title):
    """Scatter a 2-column embedding of the training features, coloured by label."""
    plt.figure()
    sns.scatterplot(x=points[:,0], y=points[:,1], hue=train_labels, palette='deep')
    plt.title(title)
    plt.grid(True)
    plt.show()

print("Visualizing features with PCA and t-SNE...")
pca = PCA(n_components=2)
pca_proj = pca.fit_transform(train_features)
_show_projection(pca_proj, 'PCA Projection of Features')

tsne_proj = TSNE(n_components=2, perplexity=30, random_state=1).fit_transform(train_features)
_show_projection(tsne_proj, 't-SNE Projection of Features')
