### Imports & GPU Configuration

In [1]:
# ==============================================================================
# 0. Importing Dependencies
# ==============================================================================
import os
import glob
import numpy as np
import pandas as pd
import tensorflow as tf
import pywt
import datetime
import json
from typing import List, Tuple, Dict, Optional

# Keras and TensorFlow Layers for the CLISA Model
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Conv2D, DepthwiseConv2D, SeparableConv2D, BatchNormalization,
    AveragePooling2D, Bidirectional, LSTM, Dropout, Dense, concatenate,
    Layer, GlobalAveragePooling1D, Reshape
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, TensorBoard
from tensorflow.keras.constraints import max_norm

# Scikit-learn for data splitting and class weights
try:
    from sklearn.model_selection import train_test_split
    from sklearn.utils.class_weight import compute_class_weight
    _HAVE_SKLEARN = True
except ImportError:
    _HAVE_SKLEARN = False

# ==============================================================================
# 1. GPU Memory Configuration
# ==============================================================================
# Ask TensorFlow to grow GPU memory on demand rather than reserving it all
# up front. set_memory_growth can raise RuntimeError; in that case the error
# is printed and the cell carries on.
gpus = tf.config.experimental.list_physical_devices('GPU')
if len(gpus) > 0:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)
    else:
        print(f"GPU memory growth configured for {len(gpus)} device(s).")

2025-09-22 06:11:21.488795: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


GPU memory growth configured for 1 device(s).


### Global Constants & Configuration

In [2]:
# ==============================================================================
# 2. Configuration and Constants
# ==============================================================================
# Emotion labels. The position of a label in this list is used as its
# integer class index when targets are encoded.
CLASS_NAMES = [
    'happy',
    'sad',
    'surprised',
    'satisfied',
    'protected',
    'frightened',
    'angry',
    'unconcerned',
]

# Dataset sub-folder name -> emotion label from CLASS_NAMES.
FOLDER_TO_CLASS = dict(
    Happy='happy',
    Sad='sad',
    Surprise='surprised',
    Satisfied='satisfied',
    Protected='protected',
    Frightened='frightened',
    Angry='angry',
    Unconcerned='unconcerned',
)

# The model from the paper was designed for 62 channels; it is adapted here
# to the 14 core EEG channels listed below (CSV column names).
EEG_CHANNELS = [
    f'EEG.{site}'
    for site in (
        'AF3', 'F7', 'F3', 'FC5', 'T7', 'P7', 'O1',
        'O2', 'P8', 'T8', 'FC6', 'F4', 'F8', 'AF4',
    )
]

### Data Loading (Adapted for CLISA)

In [3]:
# ==============================================================================
# 3. Data Loading and Preprocessing Functions (Adapted for CLISA)
# ==============================================================================

# Helper functions: wavelet denoising, CSV reading, shape fixing and per-sample normalization.

def wavelet_denoise(data, wavelet='db4', level=4):
    """Denoise a 1D signal by soft-thresholding its wavelet detail coefficients.

    The signal is decomposed with ``pywt.wavedec``. The noise scale is estimated
    from the finest detail band as ``MAD / 0.6745`` and the threshold
    ``sigma * sqrt(2 * ln(N))`` is applied (soft mode) to every detail band; the
    approximation band is kept as is. The result is reconstructed with
    ``pywt.waverec`` and trimmed to the input length.

    Args:
        data: 1D array-like signal.
        wavelet: Name of the discrete wavelet to use.
        level: Requested decomposition depth. It is capped at
            ``pywt.dwt_max_level`` for the signal length so short signals are
            not decomposed deeper than the wavelet supports.

    Returns:
        Array with the same number of samples as ``data``. Signals too short
        for even a one-level decomposition are returned unchanged (as floats).
    """
    signal = np.asarray(data)
    n_samples = len(signal)

    # Cap the depth: decomposing beyond the maximum useful level only adds
    # boundary artefacts (PyWavelets emits a warning in that case).
    max_level = 0
    if n_samples > 0:
        max_level = pywt.dwt_max_level(n_samples, pywt.Wavelet(wavelet).dec_len)
    level = min(level, max_level)
    if level < 1:
        # Nothing to threshold: there are no detail bands for such a short signal.
        return signal.astype(float)

    coeffs = pywt.wavedec(signal, wavelet, level=level)

    # Robust noise estimate from the finest-scale detail coefficients.
    finest = coeffs[-1]
    sigma = np.median(np.abs(finest - np.median(finest))) / 0.6745
    threshold = sigma * np.sqrt(2 * np.log(n_samples))

    # coeffs[0] is the approximation band; only the detail bands are shrunk.
    new_coeffs = [coeffs[0]] + [
        pywt.threshold(detail, value=threshold, mode='soft') for detail in coeffs[1:]
    ]
    reconstructed_signal = pywt.waverec(new_coeffs, wavelet)
    # waverec may return one extra sample for odd-length inputs.
    return reconstructed_signal[:n_samples]

def _read_and_denoise_csv(filepath: str) -> np.ndarray:
    """Load one recording and return its denoised EEG signals.

    The CSV is read with its first row skipped. For every name in
    ``EEG_CHANNELS`` that exists as a column, NaNs are dropped and the signal
    is wavelet-denoised (signals with zero variance are kept untouched).
    Channels that end up shorter than the longest one are zero-padded at the
    end so all columns have equal length.

    Returns:
        float32 array of shape (time_points, found_channels), with columns in
        ``EEG_CHANNELS`` order. Channels missing from the file are omitted.
    """
    frame = pd.read_csv(filepath, skiprows=1, low_memory=False)

    per_channel = {}
    for name in EEG_CHANNELS:
        if name not in frame.columns:
            continue
        values = frame[name].dropna().values
        if np.var(values) > 0:
            values = wavelet_denoise(values)
        per_channel[name] = values

    # Bring every channel up to the length of the longest one.
    longest = max((len(values) for values in per_channel.values()), default=0)
    aligned = {}
    for name, values in per_channel.items():
        shortfall = longest - len(values)
        if shortfall > 0:
            values = np.concatenate([values, np.zeros(shortfall)])
        aligned[name] = values

    return pd.DataFrame(aligned).values.astype(np.float32)

def _ensure_shape_and_pad(raw: np.ndarray, channels: int, time_steps: int) -> np.ndarray:
    """Ensure data has shape (channels, time_steps) by padding/truncating and transposing."""
    # Transpose to (channels, time_points)
    data = raw.T
    
    if data.shape[1] > time_steps:
        data = data[:, :time_steps]
    elif data.shape[1] < time_steps:
        padding = np.zeros((channels, time_steps - data.shape[1]), dtype=data.dtype)
        data = np.concatenate([data, padding], axis=1)
    return data

def _normalize_per_sample(sample: np.ndarray) -> np.ndarray:
    """Normalize each sample."""
    mean = sample.mean()
    std = sample.std()
    return (sample - mean) / (std + 1e-8)

def load_eeg_dataset(
    data_dir: str, channels: int, time_steps: int,
    stressed_classes: Optional[List[str]] = None, test_size: float = 0.15,
    val_size: float = 0.15, random_state: int = 42, batch_size: int = 4
) -> Tuple[Dict[str, tf.data.Dataset], Dict]:
    """Load the EEG recordings under ``data_dir`` and build train/val/test datasets.

    Each sub-folder listed in ``FOLDER_TO_CLASS`` is scanned for ``*.csv``
    files. Every file is read and denoised, reshaped to
    ``(channels, time_steps, 1)``, and normalized per sample. Files that do not
    provide exactly ``channels`` EEG columns are skipped (and reported).

    Args:
        data_dir: Root directory containing one sub-folder per emotion.
        channels: Required number of EEG channels per recording.
        time_steps: Number of time points each sample is cut/padded to.
        stressed_classes: Labels mapped to 1 for the binary head; defaults to
            ``['frightened', 'angry']``.
        test_size: Fraction of samples used for the test split.
        val_size: Fraction of samples used for the validation split.
        random_state: Seed for the stratified splits and the training shuffle.
        batch_size: Batch size of the returned datasets.

    Returns:
        ``(datasets, meta)`` where ``datasets`` maps ``'train'``/``'val'``/
        ``'test'`` to batched ``tf.data.Dataset`` objects yielding
        ``(x, targets, sample_weights)`` with targets/weights keyed by output
        name, and ``meta`` holds per-class counts of the loaded samples, the
        total sample count, class weights, the index-to-class mapping and the
        number of skipped files.

    Raises:
        ImportError: If scikit-learn is not installed.
        ValueError: If no CSV file is found, or none has the expected
            channel count.
    """
    if not _HAVE_SKLEARN:
        raise ImportError("Scikit-learn is required.")

    if stressed_classes is None:
        stressed_classes = ['frightened', 'angry']

    files, labels = [], []
    for folder, cls in FOLDER_TO_CLASS.items():
        cls_folder = os.path.join(data_dir, folder)
        if os.path.isdir(cls_folder):
            # glob order is filesystem-dependent; sort so the seeded split
            # below selects the same files on every run/machine.
            found = sorted(glob.glob(os.path.join(cls_folder, "*.csv")))
            files.extend(found)
            labels.extend([cls] * len(found))

    if not files:
        raise ValueError(f"No CSV files found in {data_dir}.")

    X_list, y_multi_idx, y_binary = [], [], []
    skipped = 0
    for fpath, cls in zip(files, labels):
        raw = _read_and_denoise_csv(fpath)
        if raw.shape[1] != channels:
            skipped += 1  # Skip if channel count is wrong
            continue

        sample = _ensure_shape_and_pad(raw, channels, time_steps)
        sample = _normalize_per_sample(sample)
        # Add a final dimension for the CNN: (channels, time_points, 1)
        X_list.append(sample[..., np.newaxis].astype(np.float32))

        y_multi_idx.append(CLASS_NAMES.index(cls))
        y_binary.append(1 if cls in stressed_classes else 0)

    if skipped:
        print(f"Skipped {skipped} file(s) that do not contain exactly {channels} EEG channels.")
    if not X_list:
        raise ValueError(
            f"None of the {len(files)} CSV file(s) in {data_dir} has {channels} EEG channels."
        )

    X = np.stack(X_list, axis=0)
    y_multi_idx = np.array(y_multi_idx, dtype=np.int32)
    y_binary = np.array(y_binary, dtype=np.float32)
    num_classes = len(CLASS_NAMES)
    y_multi_onehot = tf.keras.utils.to_categorical(y_multi_idx, num_classes=num_classes)

    # compute_class_weight rejects classes that never occur in y, so weights
    # are computed for the classes actually present; absent classes keep 1.0.
    # When every class is present this equals the plain 'balanced' weighting.
    present_classes = np.unique(y_multi_idx)
    weights = compute_class_weight('balanced', classes=present_classes, y=y_multi_idx)
    class_weights = {i: 1.0 for i in range(num_classes)}
    for class_idx, weight in zip(present_classes, weights):
        class_weights[int(class_idx)] = float(weight)

    emotion_sample_weights = np.array([class_weights[label] for label in y_multi_idx], dtype=np.float32)
    stress_sample_weights = np.ones_like(y_binary, dtype=np.float32)

    # Stratified split: first carve out val+test, then divide that remainder.
    indices = np.arange(len(X))
    train_indices, temp_indices = train_test_split(
        indices, test_size=(test_size + val_size),
        random_state=random_state, stratify=y_multi_idx
    )
    val_indices, test_indices = train_test_split(
        temp_indices, test_size=(test_size / (test_size + val_size)),
        random_state=random_state, stratify=y_multi_idx[temp_indices]
    )

    def make_ds(inds, shuffle=False):
        """Build a batched dataset over ``inds``; only training data is shuffled."""
        x = X[inds]
        y = {'stressed_not_stressed_output': y_binary[inds], 'emotion_class_output': y_multi_onehot[inds]}
        sw = {'stressed_not_stressed_output': stress_sample_weights[inds], 'emotion_class_output': emotion_sample_weights[inds]}
        ds = tf.data.Dataset.from_tensor_slices((x, y, sw))
        if shuffle:
            ds = ds.shuffle(len(inds), seed=random_state)
        return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    # val/test keep a fixed order so labels and predictions gathered in
    # separate passes over the dataset stay aligned.
    datasets = {
        'train': make_ds(train_indices, shuffle=True),
        'val': make_ds(val_indices),
        'test': make_ds(test_indices),
    }

    # Count the samples that were actually loaded (skipped files excluded) so
    # the per-class counts add up to 'total_samples'.
    loaded_counts = np.bincount(y_multi_idx, minlength=num_classes)
    meta = {
        'counts': {cls: int(loaded_counts[i]) for i, cls in enumerate(CLASS_NAMES)},
        'total_samples': len(X),
        'class_weights': class_weights,
        'index_to_class': {i: c for i, c in enumerate(CLASS_NAMES)},
        'skipped_files': skipped,
    }
    return datasets, meta

### The CLISA Model Architecture

In [4]:
# ==============================================================================
# 5. The CLISA Model Architecture
# ==============================================================================

# Custom attention-pooling layer applied to the BiLSTM output sequence.
@tf.keras.utils.register_keras_serializable(package="CLISA")
class Attention(Layer):
    """Attention pooling over the time axis of a (batch, time, features) tensor.

    A score ``tanh(x . W + b)`` is computed for every time step, turned into
    weights with a softmax over time, and the input is summed over time using
    those weights, giving a (batch, features) output.

    The bias has one entry per time step, so the layer requires a statically
    known time dimension. The class is registered as Keras-serializable so
    models saved in the ``.keras`` format can be reloaded without passing
    ``custom_objects``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # W projects each feature vector to a scalar score; b adds a
        # per-time-step offset to that score.
        self.W = self.add_weight(name="att_weight", shape=(input_shape[-1], 1), initializer="normal")
        self.b = self.add_weight(name="att_bias", shape=(input_shape[1], 1), initializer="zeros")
        super().build(input_shape)

    def call(self, x):
        # (batch, time, features) . (features, 1) -> (batch, time, 1)
        projected = tf.tensordot(x, self.W, axes=1) + self.b
        scores = tf.squeeze(tf.tanh(projected), axis=-1)
        # Normalise the scores across time steps.
        weights = tf.expand_dims(tf.nn.softmax(scores, axis=-1), axis=-1)
        return tf.reduce_sum(x * weights, axis=1)

def create_clisa_model(input_shape, num_classes=8, F1=8, D=2, F2=16, dropout_rate=0.25):
    """Build the two-headed CLISA network (convolutional front end + BiLSTM).

    Args:
        input_shape: Sequence whose first two entries are (channels, time points).
        num_classes: Size of the softmax emotion output.
        F1: Number of temporal filters.
        D: Depth multiplier for the spatial (depthwise) filters.
        F2: Number of pointwise filters in the separable convolution.
        dropout_rate: Dropout rate used after both convolutional blocks.

    Returns:
        An uncompiled Keras ``Model`` whose outputs are a dict with keys
        ``"stressed_not_stressed_output"`` (sigmoid, 1 unit) and
        ``"emotion_class_output"`` (softmax, ``num_classes`` units).
    """
    n_channels, n_samples = input_shape[0], input_shape[1]
    # Temporal pooling factors of the two convolutional blocks.
    pool_block1, pool_block2 = 4, 8

    eeg_input = Input(shape=(n_channels, n_samples, 1))

    # Block 1: temporal convolution, then a depthwise convolution spanning
    # all channels (collapses the channel axis to 1).
    net = Conv2D(F1, (1, 64), padding='same', use_bias=False)(eeg_input)
    net = BatchNormalization()(net)
    net = DepthwiseConv2D(
        (n_channels, 1),
        use_bias=False,
        depth_multiplier=D,
        depthwise_constraint=max_norm(1.),
    )(net)
    net = BatchNormalization()(net)
    net = tf.keras.layers.Activation('elu')(net)
    net = AveragePooling2D((1, pool_block1))(net)
    net = Dropout(dropout_rate)(net)

    # Block 2: separable (depthwise + pointwise) convolution.
    net = SeparableConv2D(F2, (1, 16), use_bias=False, padding='same')(net)
    net = BatchNormalization()(net)
    net = tf.keras.layers.Activation('elu')(net)
    net = AveragePooling2D((1, pool_block2))(net)
    net = Dropout(dropout_rate)(net)

    # Drop the singleton channel axis: (batch, pooled time steps, F2).
    pooled_steps = n_samples // (pool_block1 * pool_block2)
    sequence = Reshape((pooled_steps, F2))(net)

    # Recurrent encoder.
    sequence = Bidirectional(LSTM(64, return_sequences=True))(sequence)
    sequence = Dropout(0.3)(sequence)

    # Summarise the sequence two ways and join them for both heads.
    attended = Attention()(sequence)
    averaged = GlobalAveragePooling1D()(sequence)
    shared = concatenate([averaged, attended])

    # Binary head.
    stress_branch = Dense(32, activation='relu')(shared)
    stress_branch = Dropout(0.5)(stress_branch)
    stress_out = Dense(1, activation='sigmoid', name='stressed_not_stressed_output')(stress_branch)

    # Multi-class head.
    emotion_branch = Dense(32, activation='relu')(shared)
    emotion_branch = Dropout(0.5)(emotion_branch)
    emotion_out = Dense(num_classes, activation='softmax', name='emotion_class_output')(emotion_branch)

    outputs = {
        "stressed_not_stressed_output": stress_out,
        "emotion_class_output": emotion_out,
    }
    return Model(inputs=eeg_input, outputs=outputs)

### Main Execution Block

In [5]:
# ==============================================================================
# 6. Main Execution Block
# ==============================================================================
if __name__ == '__main__':
    # --- Reproducibility ---
    # Seed Python, NumPy and TensorFlow so weight initialisation, dropout and
    # shuffling are repeatable between runs.
    SEED = 42
    tf.keras.utils.set_random_seed(SEED)

    # --- Path and Model Parameters ---
    # Override the dataset location with the EEG_DATASET_DIR environment
    # variable; the default is the original author's local path.
    dataset_path = os.environ.get(
        "EEG_DATASET_DIR",
        "/media/kd/New Volume/Github/EEG-Emotion-Detection/dataset"
    )

    # Parameters adapted for the CLISA architecture
    # 512 samples per window. The paper uses 1 second of data at 200Hz; at a
    # 128Hz sampling rate 512 samples would be 4 seconds — TODO confirm the
    # recording sampling rate.
    INPUT_TIME_STEPS = 512
    INPUT_CHANNELS = len(EEG_CHANNELS)
    INPUT_SHAPE = (INPUT_CHANNELS, INPUT_TIME_STEPS) # (Channels, Time)
    BATCH_SIZE = 8 # This model is efficient, so we can try a slightly larger batch size
    EPOCHS = 300

    # --- Load Data ---
    print("--- Loading and Preprocessing Dataset ---")
    datasets, meta = load_eeg_dataset(
        data_dir=dataset_path,
        channels=INPUT_CHANNELS,
        time_steps=INPUT_TIME_STEPS,
        random_state=SEED,
        batch_size=BATCH_SIZE
    )

    # --- Build and Compile Model ---
    print("\n--- Building CLISA Model ---")
    model = create_clisa_model(INPUT_SHAPE, num_classes=len(CLASS_NAMES))
    optimizer = Adam(learning_rate=1e-3)

    # The emotion head dominates the total loss (weight 1.0 vs 0.5 for the
    # binary stress head).
    model.compile(
        optimizer=optimizer,
        loss={'stressed_not_stressed_output': 'binary_crossentropy', 'emotion_class_output': 'categorical_crossentropy'},
        loss_weights={'stressed_not_stressed_output': 0.5, 'emotion_class_output': 1.0},
        metrics={'stressed_not_stressed_output': 'accuracy', 'emotion_class_output': 'accuracy'}
    )
    model.summary()

    # --- Define Callbacks ---
    os.makedirs('models', exist_ok=True)
    os.makedirs('logs', exist_ok=True)
    timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

    # Every callback tracks validation accuracy of the emotion head.
    monitored_metric = 'val_emotion_class_output_accuracy'

    callbacks = [
        ModelCheckpoint(
            filepath=f'models/best_clisa_model_{timestamp}.keras',
            monitor=monitored_metric,
            save_best_only=True, mode='max', verbose=1
        ),
        # NOTE(review): patience (1000) is larger than EPOCHS (300), so this
        # callback can never stop training early — confirm this is intended.
        EarlyStopping(
            monitor=monitored_metric,
            patience=1000, # Give it more patience to learn
            restore_best_weights=True, mode='max', verbose=1
        ),
        ReduceLROnPlateau(
            monitor=monitored_metric,
            factor=0.2, patience=8, min_lr=1e-6, mode='max', verbose=1
        ),
        TensorBoard(log_dir=f'logs/fit/{timestamp}', histogram_freq=1)
    ]

    # --- Train the Model ---
    print("\n--- Starting Model Training ---")
    history = model.fit(
        datasets['train'],
        validation_data=datasets['val'],
        epochs=EPOCHS, # Train for more epochs
        callbacks=callbacks,
        verbose=1
    )

    # --- Final Evaluation and Saving ---
    # TODO: the evaluation and saving code is not included in this cell.

--- Loading and Preprocessing Dataset ---


I0000 00:00:1758501700.869702  119952 gpu_device.cc:2020] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 797 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3070 Ti, pci bus id: 0000:01:00.0, compute capability: 8.6



--- Building CLISA Model ---



--- Starting Model Training ---
Epoch 1/300


E0000 00:00:1758501704.709671  119952 meta_optimizer.cc:967] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape inStatefulPartitionedCall/functional_1/dropout_1/stateless_dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer
2025-09-22 06:11:45.017903: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:473] Loaded cuDNN version 91300


[1m19/21[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 9ms/step - emotion_class_output_accuracy: 0.1130 - emotion_class_output_loss: 2.1057 - loss: 2.4537 - stressed_not_stressed_output_accuracy: 0.5566 - stressed_not_stressed_output_loss: 0.6960
Epoch 1: val_emotion_class_output_accuracy improved from None to 0.08333, saving model to models/best_clisa_model_20250922-061142.keras
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 41ms/step - emotion_class_output_accuracy: 0.1024 - emotion_class_output_loss: 2.1131 - loss: 2.4496 - stressed_not_stressed_output_accuracy: 0.6386 - stressed_not_stressed_output_loss: 0.6705 - val_emotion_class_output_accuracy: 0.0833 - val_emotion_class_output_loss: 2.0835 - val_loss: 2.4128 - val_stressed_not_stressed_output_accuracy: 0.7500 - val_stressed_not_stressed_output_loss: 0.6653 - learning_rate: 0.0010
Epoch 2/300
[1m19/21[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 9ms/step - emotion_class_output_accuracy: 0.1630