In [19]:
import pandas as pd
import numpy as np

data = pd.read_csv("../data/adhdata.csv")

# EEG columns
eeg_columns = ['Fp1', 'Fp2', 'F3', 'F4', 'C3', 'C4', 'P3', 'P4', 'O1', 'O2', 'F7', 'F8', 'T7', 'T8', 'P7', 'P8', 'Fz', 'Cz', 'Pz']
raw_eeg = data[eeg_columns].values.T  # shape: (channels, samples)
raw_labels = data['Class'].tolist()

In [20]:
#filter and segment
from scipy.signal import butter, filtfilt

def bandpass_filter(signal, low=1, high=30, fs=128):
    nyq = 0.5 * fs
    b, a = butter(5, [low/nyq, high/nyq], btype='band')
    return filtfilt(b, a, signal)

def segment_eeg(eeg, segment_len=384, step=128):
    segments = []
    for start in range(0, eeg.shape[1] - segment_len + 1, step):
        segments.append(eeg[:, start:start+segment_len])
    return segments

In [21]:
#split in segments

from scipy.signal import butter, filtfilt

def bandpass_filter(signal, low=1, high=30, fs=128):
    nyq = 0.5 * fs
    b, a = butter(5, [low/nyq, high/nyq], btype='band')
    return filtfilt(b, a, signal)

def segment_eeg(eeg, segment_len=384, step=128):
    segments = []
    for start in range(0, eeg.shape[1] - segment_len + 1, step):
        segments.append(eeg[:, start:start+segment_len])
    return segments


In [22]:
#make cwt spectrograms

import pywt
import numpy as np
import cv2

def compute_cwt_spectrogram(segment, wavelet='morl'):
    spec = []
    for ch in segment:
        coef, _ = pywt.cwt(ch, scales=np.arange(1, 64), wavelet=wavelet)
        spec.append(np.abs(coef))
    spec = np.stack(spec, axis=-1)  # shape: (scales, time, channels)
    return cv2.resize(spec, (224, 224))  # ResNet input size

In [23]:
from tqdm import tqdm

def create_dataset(eeg, labels):
    X, y = [], []
    eeg_filtered = np.array([bandpass_filter(ch) for ch in eeg])
    segments = segment_eeg(eeg_filtered)
    segment_len = 384
    step = 128
    for i, seg in enumerate(tqdm(segments, desc="Creating dataset")):
        spec = compute_cwt_spectrogram(seg)
        start = i * step
        end = start + segment_len
        segment_labels = labels[start:end]
        if segment_labels:
            from collections import Counter
            label = Counter(segment_labels).most_common(1)[0][0]
            X.append(spec)
            y.append(label)
    return np.array(X), np.array(y)

X, y = create_dataset(raw_eeg, raw_labels)
X = X / X.max()  # normalize


Creating dataset:  27%|██▋       | 4499/16922 [05:36<15:28, 13.38it/s]  


SystemError: <built-in function resize> returned a result with an exception set

In [None]:
# Train ResNet model

from sklearn.model_selection import train_test_split
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Input
from tensorflow.keras.optimizers import Adam

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y)

def build_resnet(input_shape):
    base_model = ResNet50(include_top=False, weights=None, input_shape=input_shape)
    x = GlobalAveragePooling2D()(base_model.output)
    x = Dense(64, activation='relu')(x)
    out = Dense(2, activation='softmax')(x)
    return Model(inputs=base_model.input, outputs=out)

model = build_resnet((224, 224, X.shape[3]))
model.compile(optimizer=Adam(1e-4), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(X_train, y_train, epochs=10, batch_size=32, validation_split=0.2)


In [None]:
# Evaluate the model

from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

y_pred = model.predict(X_test)
y_pred_cls = np.argmax(y_pred, axis=1)

print(classification_report(y_test, y_pred_cls))
sns.heatmap(confusion_matrix(y_test, y_pred_cls), annot=True, fmt="d", cmap="Blues")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix")
plt.show()
