# In [None]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from scipy.signal import spectrogram, iirnotch, cheby1, filtfilt
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv1D, MaxPooling1D, Flatten, Dense, Dropout
from scipy.optimize import minimize
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt


def fit_gauss(params, x, y):
    """Return the sum of squared residuals between ``y`` and a Gaussian curve.

    This is the objective minimised by ``scipy.optimize.minimize`` when fitting
    a Gaussian to a histogram.

    Parameters
    ----------
    params : sequence of three floats
        ``(mu, sig, amp)``: centre, width and peak height of the Gaussian.
    x : numpy.ndarray
        Sample positions (e.g. histogram bin centres).
    y : numpy.ndarray
        Observed values at ``x``.

    Returns
    -------
    float
        ``sum((y - amp * exp(-0.5 * ((x - mu) / sig) ** 2)) ** 2)``.
    """
    center, width, height = params
    z = (x - center) / width
    model = height * np.exp(-0.5 * z ** 2)
    residual = y - model
    return np.sum(residual ** 2)

def sixtyHzFilt_EEG(signal, fs, freq=60.0, Q=30.0):
    """Remove a narrow band around ``freq`` (mains hum) with a zero-phase IIR notch.

    Parameters
    ----------
    signal : array_like
        Samples to filter.
    fs : float
        Sampling frequency in Hz.
    freq : float, optional
        Notch centre frequency in Hz (default 60.0).
    Q : float, optional
        Quality factor of the notch (default 30.0).

    Returns
    -------
    numpy.ndarray
        The notch-filtered signal, run forward and backward by ``filtfilt``.
    """
    # Without an ``fs`` argument, iirnotch expects the centre frequency
    # normalised to the Nyquist frequency (fs / 2).
    w0 = freq / (0.5 * fs)
    numerator, denominator = iirnotch(w0, Q)
    return filtfilt(numerator, denominator, signal)

def highPassChebyshev1Filt_EEG(signal, fs, cutoff=0.5, rp=0.1, order=5):
    """Zero-phase Chebyshev type I high-pass filter.

    Parameters
    ----------
    signal : array_like
        Samples to filter. Multi-dimensional input is filtered along the last
        axis, as ``sosfiltfilt`` does by default.
    fs : float
        Sampling frequency in Hz.
    cutoff : float, optional
        High-pass corner frequency in Hz (default 0.5).
    rp : float, optional
        Maximum passband ripple in dB (default 0.1).
    order : int, optional
        Filter order (default 5).

    Returns
    -------
    numpy.ndarray
        The high-pass filtered signal, same shape as ``signal``.
    """
    # Local import keeps this block self-contained; the file-level import list
    # does not include sosfiltfilt.
    from scipy.signal import sosfiltfilt

    nyquist = 0.5 * fs
    low = cutoff / nyquist
    # Use second-order sections rather than (b, a). With the defaults the
    # normalised corner is ~0.006 (fs = 173.61 Hz), where expanding a
    # 5th-order filter into polynomial coefficients is numerically
    # ill-conditioned. The SciPy docs recommend output='sos' for filtering.
    sos = cheby1(order, rp, low, btype='high', output='sos')
    return sosfiltfilt(sos, signal)

def normalizeEEG(signal, fs):
    """Filter an EEG trace and standardise it using a Gaussian fit to its amplitude histogram.

    The signal is notch-filtered (``sixtyHzFilt_EEG``) and then high-pass
    filtered (``highPassChebyshev1Filt_EEG``). A Gaussian is least-squares
    fitted to the density histogram of the non-zero filtered samples. The
    filtered signal is then shifted by the fitted mean and divided by the
    fitted width.

    Parameters
    ----------
    signal : array_like
        Raw EEG samples (presumably 1-D; confirm with callers). Pass the
        unfiltered signal, because both filters are applied here.
    fs : float
        Sampling frequency in Hz.

    Returns
    -------
    numpy.ndarray
        ``(filtered - mu) / sig``, same shape as the filtered signal.

    Raises
    ------
    ValueError
        If every filtered sample is exactly zero, leaving nothing to fit.
    """
    # Apply filters
    filtSignal = sixtyHzFilt_EEG(signal, fs)
    filtSignal = highPassChebyshev1Filt_EEG(filtSignal, fs)

    # Remove flatline zeros before building the histogram.
    # NOTE(review): the mask is taken on the *filtered* signal. After IIR
    # filtering, zero-valued flatline segments are generally no longer exactly
    # zero, so this may drop fewer samples than intended. Confirm the intent.
    tempSignal = filtSignal[filtSignal != 0]
    if tempSignal.size == 0:
        raise ValueError("normalizeEEG: no non-zero samples left after filtering")

    # Histogram (sqrt(n) bins, normalised to a density) and initial guess
    hist, bin_edges = np.histogram(tempSignal, bins=int(np.sqrt(tempSignal.size)), density=True)
    x = (bin_edges[1:] + bin_edges[:-1]) / 2  # Bin centers
    sample_mu = np.mean(tempSignal)
    sample_sig = np.std(tempSignal)
    guess = [sample_mu, sample_sig, np.max(hist)]

    # Optimize Gaussian fit. The amplitude is fitted but not used afterwards.
    result = minimize(fit_gauss, guess, args=(x, hist))
    mu, sig, _amp = result.x

    # fit_gauss uses sig only inside ((x - mu) / sig) ** 2, so a negative width
    # fits exactly as well as a positive one. Dividing by a negative width
    # would flip the polarity of the output, so take its magnitude.
    sig = abs(sig)
    if not (np.isfinite(mu) and np.isfinite(sig)) or sig == 0:
        # Fit diverged or collapsed: fall back to the sample moments.
        mu, sig = sample_mu, sample_sig

    # Normalize the signal using the Gaussian parameters
    return (filtSignal - mu) / sig

# Spectrogram Generation
def compute_spectrogram(signal, fs):
    """Return the power spectrogram of ``signal`` with time along the first axis.

    A 64-sample window with 50% overlap (32 samples) is used.

    Parameters
    ----------
    signal : array_like
        Samples to analyse.
    fs : float
        Sampling frequency in Hz.

    Returns
    -------
    numpy.ndarray
        Transpose of ``scipy.signal.spectrogram``'s ``Sxx``. For 1-D input
        this is (time segments, frequency bins).
    """
    window_len = 64
    _freqs, _times, power = spectrogram(
        signal, fs, nperseg=window_len, noverlap=window_len // 2
    )
    # scipy returns (frequency, time); put time first.
    return np.transpose(power)

# CNN Model
def create_cnn_model(input_shape):
    """Build and compile a small 1-D CNN for binary classification.

    Architecture:
    - two pointwise (kernel_size=1) Conv1D + MaxPooling1D stages with 64 and
      128 filters;
    - a Dense(128) + Dropout(0.5) head;
    - a single sigmoid output unit.

    Parameters
    ----------
    input_shape : tuple
        Per-sample input shape ``(steps, channels)`` for the first Conv1D layer.

    Returns
    -------
    Sequential
        The model compiled with Adam, binary cross-entropy loss and accuracy.
    """
    model = Sequential()
    # Feature extractor (kernel_size=1: each time step is mixed across channels only)
    model.add(Conv1D(filters=64, kernel_size=1, activation='relu', input_shape=input_shape, padding='same'))
    model.add(MaxPooling1D(pool_size=2))
    model.add(Conv1D(filters=128, kernel_size=1, activation='relu', padding='same'))
    model.add(MaxPooling1D(pool_size=2))
    # Classifier head
    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model


# Load EEG data and labels
fs = 173.61  # Sampling frequency (Hz)
# NOTE(review): each row of the signal sheet is treated as one EEG segment
# (iterated below), and the label sheet is flattened to one label per segment.
# Confirm the two files are row-aligned.
eeg_data = pd.read_excel('data/bonn_shortSignal.xlsx', header=None).values
labels = pd.read_excel('data/bonnLabels_shortSignal.xlsx', header=None).values.flatten()

# Preprocess the data.
# `filtered_data` (notch + high-pass) is kept so it stays available for inspection.
filtered_data = [highPassChebyshev1Filt_EEG(sixtyHzFilt_EEG(signal, fs), fs) for signal in eeg_data]
# normalizeEEG applies the same notch and high-pass filters itself, so it must
# receive the RAW signals. Passing `filtered_data` would filter every segment twice.
normalized_data = [normalizeEEG(signal, fs) for signal in eeg_data]
spectrograms = [compute_spectrogram(signal, fs) for signal in normalized_data]

# Convert to NumPy arrays. Each spectrogram is (time, freq), so X is
# (segments, time, freq) before the trailing singleton axis is added.
# NOTE(review): the model is built with input_shape=(time, freq). The extra
# size-1 axis appears to work only because Keras squeezes a trailing length-1
# axis on model inputs. Confirm this with the Keras version in use.
X = np.array(spectrograms)
X = X[..., np.newaxis]
y = np.array(labels)

# Train-test split, stratified to preserve class proportions.
# NOTE(review): this 5% split is also passed as validation_data below, so it
# is not an untouched hold-out if it drives any model-selection decision.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.05, stratify=y, random_state=42)

# Create and train the model
input_shape = (X_train.shape[1], X_train.shape[2])
cnn_model = create_cnn_model(input_shape)

history = cnn_model.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=100, batch_size=32)

# Evaluate the model
test_loss, test_accuracy = cnn_model.evaluate(X_test, y_test)
print(f"Test Accuracy: {test_accuracy:.2f}")

# Classification report. The model has a single sigmoid output, so predict()
# gives one probability per row. Threshold at 0.5 and flatten to a 1-D label vector.
y_pred = (cnn_model.predict(X_test) > 0.5).astype(int).ravel()
print(classification_report(y_test, y_pred))

cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=['Class 0', 'Class 1'], yticklabels=['Class 0', 'Class 1'])
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.show()

# After training the model, save it as an .h5 file
cnn_model.save('cnn_model.h5')