In [1]:
# ================================================================
# CBSAtt Implementation - Exact Paper Replication
# Paper: "CBSAtt: a CNN-BiLSTM network with multi-head self-attention 
#         for EEG emotion recognition"
# ================================================================

import gc
import os
import pickle
import re
import subprocess
import time
import warnings

import numpy as np
from scipy import signal as scipy_signal
from scipy.stats import zscore

# Set before importing TensorFlow so warnings raised during its import are silenced too.
warnings.filterwarnings('ignore')

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model
from tensorflow.keras.callbacks import Callback
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, balanced_accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# ================================================================
# Setup rclone for Google Drive
# ================================================================
print("=" * 70)
print("SETTING UP RCLONE FOR GOOGLE DRIVE")
print("=" * 70)

# Install rclone with the official install script. This is an IPython shell
# escape (`!`), so this cell only runs inside a notebook, not as a plain .py.
!curl https://rclone.org/install.sh | bash

# Fetch the rclone configuration text from the Kaggle secret "RCLONE_CONFIG".
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
rclone_config = user_secrets.get_secret("RCLONE_CONFIG")

# If the secret contains no newlines (presumably it was pasted as one line),
# rebuild a parseable INI layout: put a newline after every "[section]" header
# and start each listed key on its own line.
# NOTE(review): only the keys listed below are split out; any other key stays
# glued to the previous line, and the regex would also match any other
# "[...]" text inside values. Extend/verify if the remote uses more keys.
if '\n' not in rclone_config:
    rclone_config = re.sub(r'\[(.*?)\]', r'[\1]\n', rclone_config)
    for key in ['type', 'scope', 'token', 'team_drive', 'client_id', 'client_secret', 'project_id']:
        rclone_config = rclone_config.replace(f' {key} =', f'\n{key} =')

# Write the config to ~/.config/rclone/rclone.conf for root (rclone's default
# config location), creating the directory if needed.
os.makedirs('/root/.config/rclone', exist_ok=True)
with open('/root/.config/rclone/rclone.conf', 'w') as f:
    f.write(rclone_config)

print("rclone configured successfully!")
print("\nTesting connection...")
# List the remote's top-level folders as a connectivity/auth check, then make
# sure the destination folder used by the checkpoint uploads exists.
!rclone lsd gdrive: 2>&1
!rclone mkdir gdrive:/deap_cbsatt_models 2>&1
print("Ready to save to Google Drive: gdrive:/deap_cbsatt_models/\n")

# ================================================================
# Load DEAP Dataset
# ================================================================
print("=" * 70)
print("LOADING DEAP DATASET")
print("=" * 70)

BASE_PATH = "/kaggle/input/deap-dataset"


def _has_dat_files(directory):
    """Return True if ``directory`` is a readable directory directly containing a ``.dat`` file."""
    try:
        return any(name.endswith('.dat') for name in os.listdir(directory))
    except OSError:  # missing path, not a directory, permission denied, ...
        return False


def _find_deap_data_path(base_path):
    """Locate the folder holding DEAP's preprocessed-python ``.dat`` files.

    Walks ``base_path`` top-down and returns the first directory whose path
    contains ``data_preprocessed_python`` AND that directly holds ``.dat``
    files. Requiring the files means a wrapper folder such as
    ``data_preprocessed_python/data_preprocessed_python/`` is skipped in favour
    of the inner folder that actually contains the data. If the walk finds
    nothing, a few fixed candidate locations are tried.

    Raises:
        FileNotFoundError: if no candidate directory contains ``.dat`` files.
    """
    for root, _dirs, files in os.walk(base_path):
        if 'data_preprocessed_python' in root and any(name.endswith('.dat') for name in files):
            return root

    candidates = [
        os.path.join(base_path, "data_preprocessed_python"),
        os.path.join(base_path, "deap-dataset", "data_preprocessed_python"),
        base_path,
    ]
    for candidate in candidates:
        if _has_dat_files(candidate):
            return candidate

    raise FileNotFoundError("Could not find 'data_preprocessed_python' directory!")


data_path = _find_deap_data_path(BASE_PATH)

print(f"Data path found: {data_path}")

def load_deap_data(data_path):
    """Load every ``.dat`` file in ``data_path`` and stack them per subject.

    Files are read in sorted filename order, so the subject axis is
    deterministic. A file that fails to load is reported and skipped
    (best effort); the shapes of the first successfully loaded file are printed.

    Args:
        data_path: Directory containing the DEAP preprocessed-python ``.dat``
            pickles, each a dict with ``'data'`` and ``'labels'`` arrays.

    Returns:
        ``(data, labels)``: the per-file ``'data'`` and ``'labels'`` arrays
        stacked along a new leading (subject) axis.

    Raises:
        ValueError: if ``data_path`` does not exist, if no file could be
            loaded, or if the loaded files have mismatched array shapes.
    """
    if not os.path.exists(data_path):
        raise ValueError("Invalid data path!")

    dat_files = sorted(name for name in os.listdir(data_path) if name.endswith('.dat'))
    all_data, all_labels = [], []

    print("\nLoading participant data...")
    for filename in dat_files:
        filepath = os.path.join(data_path, filename)
        try:
            # NOTE: pickle can execute arbitrary code while loading; only use
            # this on trusted dataset files. encoding='latin1' lets Python 3
            # read pickles written by Python 2 (as the DEAP files presumably are).
            with open(filepath, 'rb') as f:
                subject_data = pickle.load(f, encoding='latin1')
            subject_eeg = subject_data['data']
            subject_labels = subject_data['labels']
        except Exception as e:  # best effort: report and skip unreadable files
            print(f"Error loading {filename}: {e}")
            continue

        all_data.append(subject_eeg)
        all_labels.append(subject_labels)

        if len(all_data) == 1:
            print(f"  Data shape: {subject_eeg.shape}")
            print(f"  Labels shape: {subject_labels.shape}")

    if not all_data:
        raise ValueError(f"No DEAP .dat files could be loaded from {data_path!r}")

    # np.stack insists on identical per-file shapes, so an inconsistent file
    # fails here with a clear error instead of producing a ragged array.
    return np.stack(all_data), np.stack(all_labels)

eeg_data, emotion_labels = load_deap_data(data_path)

# Summarise what was loaded: one status line, then the shapes of the stacked arrays.
print(
    "\nDataset Loaded Successfully",
    f"  EEG data shape: {eeg_data.shape}",
    f"  Labels shape: {emotion_labels.shape}",
    sep="\n",
)

# ================================================================
# Channel Selection (Paper: 16 channels as shown in Fig. 1)
# ================================================================
# Electrode order of the 32 EEG channels in DEAP's data_preprocessed_python
# files: list position == index along the channel axis.
DEAP_CHANNEL_NAMES = [
    'Fp1', 'AF3', 'F3', 'F7', 'FC5', 'FC1', 'C3', 'T7',
    'CP5', 'CP1', 'P3', 'P7', 'PO3', 'O1', 'Oz', 'Pz',
    'Fp2', 'AF4', 'Fz', 'F4', 'F8', 'FC6', 'FC2', 'Cz',
    'C4', 'T8', 'CP6', 'CP2', 'P4', 'P8', 'PO4', 'O2',
]
# The 16 electrodes used by the paper (Fig. 1).
PAPER_CHANNEL_NAMES = [
    'Fp1', 'Fp2', 'F3', 'F4', 'F7', 'F8', 'FC5', 'FC6',
    'T7', 'T8', 'P7', 'P8', 'O1', 'O2', 'AF3', 'AF4',
]
# Indices are derived by name so they cannot drift from the electrode list;
# sorted to keep DEAP's native channel order for the downstream sequence model.
SELECTED_CHANNELS = sorted(DEAP_CHANNEL_NAMES.index(name) for name in PAPER_CHANNEL_NAMES)

print(f"\nChannel Selection: {len(SELECTED_CHANNELS)} channels selected")

# ================================================================
# Preprocessing (Paper: Sec 3.1)
# 1. Bandpass filter 4-45 Hz
# 2. Z-score normalization
# 3. 6s time window with 50% overlap
# ================================================================

def bandpass_filter(data, lowcut=4, highcut=45, fs=128, order=5):
    """Zero-phase Butterworth band-pass filter applied along the last axis.

    Args:
        data: Array of samples. Filtering runs along the last axis, so any
            leading axes (channels, trials, ...) are filtered independently.
        lowcut: Lower pass-band edge in Hz.
        highcut: Upper pass-band edge in Hz.
        fs: Sampling rate in Hz.
        order: Butterworth design order.

    Returns:
        The filtered array, same shape as ``data``.

    Raises:
        ValueError: unless ``0 < lowcut < highcut < fs / 2``.
    """
    nyquist = 0.5 * fs
    if not 0 < lowcut < highcut < nyquist:
        raise ValueError(
            f"Band edges must satisfy 0 < lowcut < highcut < fs/2 "
            f"(got lowcut={lowcut}, highcut={highcut}, fs={fs})"
        )
    # Second-order sections avoid the numerical problems of (b, a) polynomial
    # coefficients for higher-order band-pass designs (SciPy recommends 'sos').
    sos = scipy_signal.butter(order, [lowcut / nyquist, highcut / nyquist], btype='band', output='sos')
    # Forward-backward filtering cancels the phase shift (zero-phase output).
    return scipy_signal.sosfiltfilt(sos, data, axis=-1)

def preprocess_deap(eeg_data, emotion_labels, selected_channels, window_size=6, overlap=0.5, fs=128):
    """Band-pass filter, z-score and segment DEAP trials (paper Sec. 3.1).

    Applied to the selected channels only:
      1. 4-45 Hz band-pass (``bandpass_filter``) along the time axis.
      2. Z-score normalisation of each (trial, channel) time series.
      3. Sliding windows of ``window_size`` seconds with fractional ``overlap``.

    Args:
        eeg_data: Array of shape (subjects, trials, channels, samples). It is
            not modified: indexing with a sequence of channel indices copies it
            before the in-place steps below.
        emotion_labels: Array whose first two axes are (subjects, trials);
            every window inherits ``emotion_labels[subject, trial]``.
        selected_channels: Indices into the channel axis, in output order.
        window_size: Window length in seconds.
        overlap: Fraction of a window shared with the next one, in [0, 1).
        fs: Sampling rate in Hz.

    Returns:
        ``(windows, labels)``. ``windows`` has shape
        (subjects * trials * n_windows, len(selected_channels), int(window_size * fs)),
        ordered by subject, then trial, then window start; ``labels`` holds the
        matching per-window label rows in the same order.

    Raises:
        ValueError: if ``overlap`` is outside [0, 1), the resulting step is
            shorter than one sample, or a trial is shorter than one window.

    NOTE(review): windows start at sample 0 of each trial, so any pre-stimulus
    baseline at the start of the trial is included — confirm against the paper.
    """
    emotion_labels = np.asarray(emotion_labels)
    n_subjects, n_trials, _, n_samples = eeg_data.shape

    window_samples = int(window_size * fs)
    step_samples = int(window_samples * (1 - overlap))
    if not 0 <= overlap < 1 or step_samples < 1:
        raise ValueError(
            f"overlap must be in [0, 1) and leave a step of at least one sample (got overlap={overlap})"
        )
    if window_samples < 1 or window_samples > n_samples:
        raise ValueError(
            f"A window of {window_samples} samples does not fit in trials of {n_samples} samples"
        )
    n_windows = (n_samples - window_samples) // step_samples + 1

    # Select channels (advanced indexing returns a copy).
    eeg_data = eeg_data[:, :, selected_channels, :]

    # Bandpass filter: the filter runs along the last axis, so one call per
    # subject covers all its trials/channels while keeping temporaries small.
    print("Applying bandpass filter (4-45 Hz)...")
    for subj in range(n_subjects):
        eeg_data[subj] = bandpass_filter(eeg_data[subj], fs=fs)

    # Z-score normalization of every (trial, channel) series.
    print("Applying Z-score normalization...")
    for subj in range(n_subjects):
        eeg_data[subj] = zscore(eeg_data[subj], axis=-1)

    # Sliding window
    print(f"Applying sliding window (window={window_size}s, overlap={overlap*100}%)...")
    # Every length-`window_samples` view along time, keeping every
    # `step_samples`-th start: (subjects, trials, channels, n_windows, window).
    windows = np.lib.stride_tricks.sliding_window_view(eeg_data, window_samples, axis=-1)[..., ::step_samples, :]
    # -> (subjects, trials, n_windows, channels, window), then flatten the first
    # three axes (this reshape materialises the copy).
    windowed_data = windows.transpose(0, 1, 3, 2, 4).reshape(-1, eeg_data.shape[2], window_samples)

    # One label row per (subject, trial), repeated once per window in the same order.
    trial_labels = emotion_labels.reshape((n_subjects * n_trials,) + emotion_labels.shape[2:])
    windowed_labels = np.repeat(trial_labels, n_windows, axis=0)

    return windowed_data, windowed_labels

preprocessed_data, preprocessed_labels = preprocess_deap(eeg_data, emotion_labels, SELECTED_CHANNELS)

# Report the windowed dataset sizes.
print(
    f"\nPreprocessed data shape: {preprocessed_data.shape}",
    f"Preprocessed labels shape: {preprocessed_labels.shape}",
    sep="\n",
)

# ================================================================
# STFT Transformation (Paper: Sec 2.1, Eq. 1)
# ================================================================

def compute_stft(data, fs=128, nperseg=64, noverlap=32, batch_size=1024):
    """Return STFT magnitude spectrograms for every (sample, channel) series.

    Uses ``scipy.signal.stft`` along the last axis, which transforms every
    leading index independently. Samples are processed ``batch_size`` at a time
    into a preallocated output, which bounds the size of the temporary complex
    arrays.

    Args:
        data: Array of shape (samples, channels, timepoints).
        fs: Sampling rate in Hz.
        nperseg: STFT segment length (samples).
        noverlap: Overlap between STFT segments (samples).
        batch_size: Number of samples transformed per ``stft`` call.

    Returns:
        Array of shape (samples, channels, freq_bins, time_frames) holding
        ``|STFT|``; the phase is discarded.

    Raises:
        ValueError: if ``data`` is not 3-D, is empty, or ``batch_size < 1``.
    """
    # Assumption: using standard STFT parameters
    if data.ndim != 3:
        raise ValueError(f"Expected data of shape (samples, channels, timepoints), got {data.shape}")
    n_samples = data.shape[0]
    if n_samples == 0:
        raise ValueError("compute_stft received no samples")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    print("Computing STFT for all samples...")
    stft_data = None
    for start in range(0, n_samples, batch_size):
        stop = min(start + batch_size, n_samples)
        _, _, zxx = scipy_signal.stft(
            data[start:stop], fs=fs, nperseg=nperseg, noverlap=noverlap, axis=-1
        )
        # Use magnitude only
        magnitude = np.abs(zxx)
        if stft_data is None:
            stft_data = np.empty((n_samples,) + magnitude.shape[1:], dtype=magnitude.dtype)
        stft_data[start:stop] = magnitude

    print(f"STFT data shape: {stft_data.shape}")
    return stft_data

stft_data = compute_stft(preprocessed_data)

# Binarise the per-window ratings (paper: split at rating 5). Column 0 feeds
# the valence target, column 1 the arousal target; a rating >= 5 becomes
# class 1 ("High"), anything lower class 0 ("Low").
valence_labels = np.greater_equal(preprocessed_labels[:, 0], 5).astype(int)
arousal_labels = np.greater_equal(preprocessed_labels[:, 1], 5).astype(int)

print(
    f"\nValence distribution: Low={np.count_nonzero(valence_labels == 0)}, "
    f"High={np.count_nonzero(valence_labels == 1)}",
    f"Arousal distribution: Low={np.count_nonzero(arousal_labels == 0)}, "
    f"High={np.count_nonzero(arousal_labels == 1)}",
    sep="\n",
)

# ================================================================
# CBSAtt Model (Paper: Fig. 2, Fig. 3, Fig. 4, Fig. 5)
# ================================================================

def build_channel_cnn(input_shape, name='channel_cnn'):
    """Build the per-channel CNN feature extractor (paper Fig. 3).

    Two blocks of 3x3 Conv2D (ReLU, 'same' padding) followed by 2x2
    max-pooling with stride 2 (32 then 64 filters), then Flatten.
    Assumption: two conv blocks, as drawn in the figure.

    Args:
        input_shape: Shape of one channel's input, e.g.
            ``(freq_bins, time_bins, 1)``.
        name: Name of the returned model. When several instances are nested
            in one functional model, give each a distinct name (e.g.
            ``f'channel_cnn_{i}'``): identical names make layers ambiguous,
            and tf.keras rejects duplicate layer names.

    Returns:
        A ``Model`` mapping ``input_shape`` to a flat feature vector.
    """
    inputs = layers.Input(shape=input_shape)

    x = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
    x = layers.MaxPooling2D((2, 2), strides=2)(x)

    x = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(x)
    x = layers.MaxPooling2D((2, 2), strides=2)(x)

    x = layers.Flatten()(x)

    return Model(inputs=inputs, outputs=x, name=name)

def build_cbsatt(n_channels, freq_bins, time_bins, num_heads=4, lstm_units=128):
    """Build the CBSAtt binary classifier (paper Fig. 2-5).

    Each channel's (freq_bins, time_bins) spectrogram passes through its own
    two-block CNN (weights are not shared between channels). The flattened
    per-channel features form a length-``n_channels`` sequence, which goes
    through a BiLSTM, then multi-head self-attention and global average
    pooling, and finally a small dense head with one sigmoid output.

    Args:
        n_channels: Number of EEG channels; becomes the sequence length.
        freq_bins: STFT frequency bins per channel.
        time_bins: STFT time frames per channel.
        num_heads: Number of self-attention heads.
        lstm_units: Units per LSTM direction; also the attention ``key_dim``.

    Returns:
        Uncompiled ``Model`` named ``'CBSAtt'`` mapping
        ``(batch, n_channels, freq_bins, time_bins)`` to ``(batch, 1)``.
    """
    inputs = layers.Input(shape=(n_channels, freq_bins, time_bins))

    # Channel-independent CNN (Paper: Sec 3.2)
    channel_outputs = []
    for i in range(n_channels):
        # `idx=i` freezes the channel index when the lambda is defined. A plain
        # closure over `i` is looked up each time the Lambda runs, so every
        # branch would select the last channel whenever the graph is re-executed.
        channel_input = layers.Lambda(lambda x, idx=i: x[:, idx:idx + 1, :, :], name=f'select_ch{i}')(inputs)
        channel_input = layers.Reshape((freq_bins, time_bins, 1), name=f'reshape_ch{i}')(channel_input)

        # Paper: 3x3 conv, 2x2 maxpool with stride 2 (Fig. 3). Unique layer
        # names per channel keep the graph unambiguous.
        x = layers.Conv2D(32, (3, 3), activation='relu', padding='same', name=f'conv1_ch{i}')(channel_input)
        x = layers.MaxPooling2D((2, 2), strides=2, name=f'pool1_ch{i}')(x)

        x = layers.Conv2D(64, (3, 3), activation='relu', padding='same', name=f'conv2_ch{i}')(x)
        x = layers.MaxPooling2D((2, 2), strides=2, name=f'pool2_ch{i}')(x)

        x = layers.Flatten(name=f'flatten_ch{i}')(x)
        channel_outputs.append(x)

    # Stack the per-channel feature vectors into a (n_channels, feature_dim)
    # sequence for the BiLSTM. A merge layer is only needed for 2+ channels.
    feature_dim = channel_outputs[0].shape[-1]
    if n_channels == 1:
        x = channel_outputs[0]
    else:
        x = layers.Concatenate(name='concat_channels')(channel_outputs)
    x = layers.Reshape((n_channels, feature_dim))(x)

    # BiLSTM (Paper: Sec 3.3, Fig. 4, Eq. 2-8)
    x = layers.Bidirectional(layers.LSTM(lstm_units, return_sequences=True))(x)

    # Multi-Head Self-Attention (Paper: Sec 3.4, Fig. 5, Eq. 9-11)
    x = layers.MultiHeadAttention(num_heads=num_heads, key_dim=lstm_units)(x, x)

    # Global pooling over the channel sequence
    x = layers.GlobalAveragePooling1D()(x)

    # Classification head
    x = layers.Dense(64, activation='relu')(x)
    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(1, activation='sigmoid')(x)

    model = Model(inputs=inputs, outputs=outputs, name='CBSAtt')
    return model

SETTING UP RCLONE FOR GOOGLE DRIVE
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  4734  100  4734    0     0  10335      0 --:--:-- --:--:-- --:--:-- 10336
Archive:  rclone-current-linux-amd64.zip
   creating: tmp_unzip_dir_for_rclone/rclone-v1.71.2-linux-amd64/
  inflating: tmp_unzip_dir_for_rclone/rclone-v1.71.2-linux-amd64/README.txt  [text]  
  inflating: tmp_unzip_dir_for_rclone/rclone-v1.71.2-linux-amd64/README.html  [text]  
  inflating: tmp_unzip_dir_for_rclone/rclone-v1.71.2-linux-amd64/rclone  [binary]
  inflating: tmp_unzip_dir_for_rclone/rclone-v1.71.2-linux-amd64/rclone.1  [text]  
  inflating: tmp_unzip_dir_for_rclone/rclone-v1.71.2-linux-amd64/git-log.txt  [text]  
Purging old database entries in /usr/share/man...
Processing manual pages under /usr/share/man...
Purging old database entries in /usr/share/man/pl...
Processing manual pages under /usr/share/man

In [5]:
# ================================================================
# Channel Selection (Paper: 16 channels as shown in Fig. 1)
# ================================================================
# Electrode order of the 32 EEG channels in DEAP's data_preprocessed_python
# files: list position == index along the channel axis.
DEAP_CHANNEL_NAMES = [
    'Fp1', 'AF3', 'F3', 'F7', 'FC5', 'FC1', 'C3', 'T7',
    'CP5', 'CP1', 'P3', 'P7', 'PO3', 'O1', 'Oz', 'Pz',
    'Fp2', 'AF4', 'Fz', 'F4', 'F8', 'FC6', 'FC2', 'Cz',
    'C4', 'T8', 'CP6', 'CP2', 'P4', 'P8', 'PO4', 'O2',
]
# The 16 electrodes used by the paper (Fig. 1).
PAPER_CHANNEL_NAMES = [
    'Fp1', 'Fp2', 'F3', 'F4', 'F7', 'F8', 'FC5', 'FC6',
    'T7', 'T8', 'P7', 'P8', 'O1', 'O2', 'AF3', 'AF4',
]
# Indices are derived by name so they cannot drift from the electrode list;
# sorted to keep DEAP's native channel order for the downstream sequence model.
SELECTED_CHANNELS = sorted(DEAP_CHANNEL_NAMES.index(name) for name in PAPER_CHANNEL_NAMES)

print(f"\nChannel Selection: {len(SELECTED_CHANNELS)} channels selected")

# ================================================================
# Preprocessing (Paper: Sec 3.1)
# 1. Bandpass filter 4-45 Hz
# 2. Z-score normalization
# 3. 6s time window with 50% overlap
# ================================================================

def bandpass_filter(data, lowcut=4, highcut=45, fs=128, order=5):
    """Zero-phase Butterworth band-pass filter applied along the last axis.

    Args:
        data: Array of samples. Filtering runs along the last axis, so any
            leading axes (channels, trials, ...) are filtered independently.
        lowcut: Lower pass-band edge in Hz.
        highcut: Upper pass-band edge in Hz.
        fs: Sampling rate in Hz.
        order: Butterworth design order.

    Returns:
        The filtered array, same shape as ``data``.

    Raises:
        ValueError: unless ``0 < lowcut < highcut < fs / 2``.
    """
    nyquist = 0.5 * fs
    if not 0 < lowcut < highcut < nyquist:
        raise ValueError(
            f"Band edges must satisfy 0 < lowcut < highcut < fs/2 "
            f"(got lowcut={lowcut}, highcut={highcut}, fs={fs})"
        )
    # Second-order sections avoid the numerical problems of (b, a) polynomial
    # coefficients for higher-order band-pass designs (SciPy recommends 'sos').
    sos = scipy_signal.butter(order, [lowcut / nyquist, highcut / nyquist], btype='band', output='sos')
    # Forward-backward filtering cancels the phase shift (zero-phase output).
    return scipy_signal.sosfiltfilt(sos, data, axis=-1)

def preprocess_deap(eeg_data, emotion_labels, selected_channels, window_size=6, overlap=0.5, fs=128):
    """Band-pass filter, z-score and segment DEAP trials (paper Sec. 3.1).

    Applied to the selected channels only:
      1. 4-45 Hz band-pass (``bandpass_filter``) along the time axis.
      2. Z-score normalisation of each (trial, channel) time series.
      3. Sliding windows of ``window_size`` seconds with fractional ``overlap``.

    Args:
        eeg_data: Array of shape (subjects, trials, channels, samples). It is
            not modified: indexing with a sequence of channel indices copies it
            before the in-place steps below.
        emotion_labels: Array whose first two axes are (subjects, trials);
            every window inherits ``emotion_labels[subject, trial]``.
        selected_channels: Indices into the channel axis, in output order.
        window_size: Window length in seconds.
        overlap: Fraction of a window shared with the next one, in [0, 1).
        fs: Sampling rate in Hz.

    Returns:
        ``(windows, labels)``. ``windows`` has shape
        (subjects * trials * n_windows, len(selected_channels), int(window_size * fs)),
        ordered by subject, then trial, then window start; ``labels`` holds the
        matching per-window label rows in the same order.

    Raises:
        ValueError: if ``overlap`` is outside [0, 1), the resulting step is
            shorter than one sample, or a trial is shorter than one window.

    NOTE(review): windows start at sample 0 of each trial, so any pre-stimulus
    baseline at the start of the trial is included — confirm against the paper.
    """
    emotion_labels = np.asarray(emotion_labels)
    n_subjects, n_trials, _, n_samples = eeg_data.shape

    window_samples = int(window_size * fs)
    step_samples = int(window_samples * (1 - overlap))
    if not 0 <= overlap < 1 or step_samples < 1:
        raise ValueError(
            f"overlap must be in [0, 1) and leave a step of at least one sample (got overlap={overlap})"
        )
    if window_samples < 1 or window_samples > n_samples:
        raise ValueError(
            f"A window of {window_samples} samples does not fit in trials of {n_samples} samples"
        )
    n_windows = (n_samples - window_samples) // step_samples + 1

    # Select channels (advanced indexing returns a copy).
    eeg_data = eeg_data[:, :, selected_channels, :]

    # Bandpass filter: the filter runs along the last axis, so one call per
    # subject covers all its trials/channels while keeping temporaries small.
    print("Applying bandpass filter (4-45 Hz)...")
    for subj in range(n_subjects):
        eeg_data[subj] = bandpass_filter(eeg_data[subj], fs=fs)

    # Z-score normalization of every (trial, channel) series.
    print("Applying Z-score normalization...")
    for subj in range(n_subjects):
        eeg_data[subj] = zscore(eeg_data[subj], axis=-1)

    # Sliding window
    print(f"Applying sliding window (window={window_size}s, overlap={overlap*100}%)...")
    # Every length-`window_samples` view along time, keeping every
    # `step_samples`-th start: (subjects, trials, channels, n_windows, window).
    windows = np.lib.stride_tricks.sliding_window_view(eeg_data, window_samples, axis=-1)[..., ::step_samples, :]
    # -> (subjects, trials, n_windows, channels, window), then flatten the first
    # three axes (this reshape materialises the copy).
    windowed_data = windows.transpose(0, 1, 3, 2, 4).reshape(-1, eeg_data.shape[2], window_samples)

    # One label row per (subject, trial), repeated once per window in the same order.
    trial_labels = emotion_labels.reshape((n_subjects * n_trials,) + emotion_labels.shape[2:])
    windowed_labels = np.repeat(trial_labels, n_windows, axis=0)

    return windowed_data, windowed_labels

preprocessed_data, preprocessed_labels = preprocess_deap(eeg_data, emotion_labels, SELECTED_CHANNELS)

# Report the windowed dataset sizes.
print(
    f"\nPreprocessed data shape: {preprocessed_data.shape}",
    f"Preprocessed labels shape: {preprocessed_labels.shape}",
    sep="\n",
)

# ================================================================
# STFT Transformation (Paper: Sec 2.1, Eq. 1)
# ================================================================

def compute_stft(data, fs=128, nperseg=64, noverlap=32, batch_size=1024):
    """Return STFT magnitude spectrograms for every (sample, channel) series.

    Uses ``scipy.signal.stft`` along the last axis, which transforms every
    leading index independently. Samples are processed ``batch_size`` at a time
    into a preallocated output, which bounds the size of the temporary complex
    arrays.

    Args:
        data: Array of shape (samples, channels, timepoints).
        fs: Sampling rate in Hz.
        nperseg: STFT segment length (samples).
        noverlap: Overlap between STFT segments (samples).
        batch_size: Number of samples transformed per ``stft`` call.

    Returns:
        Array of shape (samples, channels, freq_bins, time_frames) holding
        ``|STFT|``; the phase is discarded.

    Raises:
        ValueError: if ``data`` is not 3-D, is empty, or ``batch_size < 1``.
    """
    # Assumption: using standard STFT parameters
    if data.ndim != 3:
        raise ValueError(f"Expected data of shape (samples, channels, timepoints), got {data.shape}")
    n_samples = data.shape[0]
    if n_samples == 0:
        raise ValueError("compute_stft received no samples")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    print("Computing STFT for all samples...")
    stft_data = None
    for start in range(0, n_samples, batch_size):
        stop = min(start + batch_size, n_samples)
        _, _, zxx = scipy_signal.stft(
            data[start:stop], fs=fs, nperseg=nperseg, noverlap=noverlap, axis=-1
        )
        # Use magnitude only
        magnitude = np.abs(zxx)
        if stft_data is None:
            stft_data = np.empty((n_samples,) + magnitude.shape[1:], dtype=magnitude.dtype)
        stft_data[start:stop] = magnitude

    print(f"STFT data shape: {stft_data.shape}")
    return stft_data

stft_data = compute_stft(preprocessed_data)

# Binarise the per-window ratings (paper: split at rating 5). Column 0 feeds
# the valence target, column 1 the arousal target; a rating >= 5 becomes
# class 1 ("High"), anything lower class 0 ("Low").
valence_labels = np.greater_equal(preprocessed_labels[:, 0], 5).astype(int)
arousal_labels = np.greater_equal(preprocessed_labels[:, 1], 5).astype(int)

print(
    f"\nValence distribution: Low={np.count_nonzero(valence_labels == 0)}, "
    f"High={np.count_nonzero(valence_labels == 1)}",
    f"Arousal distribution: Low={np.count_nonzero(arousal_labels == 0)}, "
    f"High={np.count_nonzero(arousal_labels == 1)}",
    sep="\n",
)

# ================================================================
# CBSAtt Model (Paper: Fig. 2, Fig. 3, Fig. 4, Fig. 5)
# ================================================================


def build_cbsatt(n_channels, freq_bins, time_bins, num_heads=4, lstm_units=128):
    """Build the CBSAtt binary classifier (paper Fig. 2-5).

    Each channel's (freq_bins, time_bins) spectrogram passes through its own
    two-block CNN (weights are not shared between channels). The flattened
    per-channel features form a length-``n_channels`` sequence, which goes
    through a BiLSTM, then multi-head self-attention and global average
    pooling, and finally a small dense head with one sigmoid output.

    Args:
        n_channels: Number of EEG channels; becomes the sequence length.
        freq_bins: STFT frequency bins per channel.
        time_bins: STFT time frames per channel.
        num_heads: Number of self-attention heads.
        lstm_units: Units per LSTM direction; also the attention ``key_dim``.

    Returns:
        Uncompiled ``Model`` named ``'CBSAtt'`` mapping
        ``(batch, n_channels, freq_bins, time_bins)`` to ``(batch, 1)``.
    """
    inputs = layers.Input(shape=(n_channels, freq_bins, time_bins))

    # Channel-independent CNN (Paper: Sec 3.2)
    channel_outputs = []
    for i in range(n_channels):
        # `idx=i` freezes the channel index when the lambda is defined; a plain
        # closure over `i` would select the last channel whenever it re-runs.
        channel_input = layers.Lambda(lambda x, idx=i: x[:, idx:idx + 1, :, :])(inputs)
        channel_input = layers.Reshape((freq_bins, time_bins, 1))(channel_input)

        # Paper: 3x3 conv, 2x2 maxpool with stride 2 (Fig. 3)
        x = layers.Conv2D(32, (3, 3), activation='relu', padding='same', name=f'conv1_ch{i}')(channel_input)
        x = layers.MaxPooling2D((2, 2), strides=2, name=f'pool1_ch{i}')(x)

        x = layers.Conv2D(64, (3, 3), activation='relu', padding='same', name=f'conv2_ch{i}')(x)
        x = layers.MaxPooling2D((2, 2), strides=2, name=f'pool2_ch{i}')(x)

        x = layers.Flatten(name=f'flatten_ch{i}')(x)
        channel_outputs.append(x)

    # Stack the per-channel feature vectors into a (n_channels, feature_dim)
    # sequence for the BiLSTM. A merge layer is only needed for 2+ channels.
    feature_dim = channel_outputs[0].shape[-1]
    if n_channels == 1:
        x = channel_outputs[0]
    else:
        x = layers.Concatenate(name='concat_channels')(channel_outputs)
    x = layers.Reshape((n_channels, feature_dim))(x)

    # BiLSTM (Paper: Sec 3.3, Fig. 4, Eq. 2-8)
    x = layers.Bidirectional(layers.LSTM(lstm_units, return_sequences=True))(x)

    # Multi-Head Self-Attention (Paper: Sec 3.4, Fig. 5, Eq. 9-11)
    x = layers.MultiHeadAttention(num_heads=num_heads, key_dim=lstm_units)(x, x)

    # Global pooling over the channel sequence
    x = layers.GlobalAveragePooling1D()(x)

    # Classification head
    x = layers.Dense(64, activation='relu')(x)
    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(1, activation='sigmoid')(x)

    model = Model(inputs=inputs, outputs=outputs, name='CBSAtt')
    return model


Channel Selection: 16 channels selected
Applying bandpass filter (4-45 Hz)...
Applying Z-score normalization...
Applying sliding window (window=6s, overlap=50.0%)...

Preprocessed data shape: (25600, 16, 768)
Preprocessed labels shape: (25600, 4)
Computing STFT for all samples...
STFT data shape: (25600, 16, 33, 25)

Valence distribution: Low=19900, High=5700
Arousal distribution: Low=19300, High=6300


In [6]:
# ================================================================
# Training Configuration (Paper: Table 3)
# ================================================================

LEARNING_RATE = 0.001
EPOCHS = 30
BATCH_SIZE = 128
NUM_HEADS = 4
LSTM_UNITS = 128

# Model input dimensions come from the spectrogram tensor built by
# compute_stft: axis 1 = channels, axis 2 = frequency bins, axis 3 = time frames.
n_channels, freq_bins, time_bins = stft_data.shape[1:4]

print(
    "\n" + "=" * 70,
    "BUILDING CBSATT MODEL",
    "=" * 70,
    f"Input shape: ({n_channels}, {freq_bins}, {time_bins})",
    f"Learning rate: {LEARNING_RATE}",
    f"Batch size: {BATCH_SIZE}",
    f"Epochs: {EPOCHS}",
    f"Number of attention heads: {NUM_HEADS}",
    f"BiLSTM hidden units: {LSTM_UNITS}",
    sep="\n",
)

# Two independent models with identical architecture: one per target.
valence_model = build_cbsatt(n_channels, freq_bins, time_bins, num_heads=NUM_HEADS, lstm_units=LSTM_UNITS)
arousal_model = build_cbsatt(n_channels, freq_bins, time_bins, num_heads=NUM_HEADS, lstm_units=LSTM_UNITS)

# Each model gets its own Adam instance so optimizer state is not shared.
valence_model.compile(optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
                      loss='binary_crossentropy', metrics=['accuracy'])
arousal_model.compile(optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
                      loss='binary_crossentropy', metrics=['accuracy'])

print("\nModel architecture:")
valence_model.summary()

# ================================================================
# Custom Callback for Google Drive Saving
# ================================================================

class RcloneSaveCallback(Callback):
    """Keras callback: checkpoint the model every ``save_every`` epochs and upload it.

    The checkpoint is written to ``/kaggle/working/<model_name>_epoch_<N>.h5``,
    where ``N`` is the 1-based epoch number. It is then copied with
    ``rclone copy`` into ``REMOTE_DIR``. A failed upload is reported but does
    not interrupt training, and the local file is kept either way.

    Args:
        model_name: Prefix for checkpoint file names.
        save_every: Checkpoint period in epochs; must be >= 1.
    """

    REMOTE_DIR = 'gdrive:/deap_cbsatt_models/'

    def __init__(self, model_name, save_every=3):
        super().__init__()
        if save_every < 1:
            raise ValueError(f"save_every must be >= 1, got {save_every}")
        self.model_name = model_name
        self.save_every = save_every

    def on_epoch_end(self, epoch, logs=None):
        epoch_num = epoch + 1  # Keras passes a 0-based epoch index
        if epoch_num % self.save_every != 0:
            return

        filename = f'{self.model_name}_epoch_{epoch_num}.h5'
        filepath = f'/kaggle/working/{filename}'
        self.model.save(filepath)

        remote_path = f'{self.REMOTE_DIR}{filename}'
        # List-form argv: no shell is involved, so unusual characters in the
        # file name cannot break the command. The exit code is checked so
        # success is only reported when the upload actually worked.
        try:
            result = subprocess.run(
                ['rclone', 'copy', filepath, self.REMOTE_DIR],
                capture_output=True, text=True,
            )
        except OSError as err:  # e.g. rclone binary not on PATH
            print(f"\nWARNING: could not run rclone for {filepath}: {err}")
            return
        if result.returncode == 0:
            print(f"\nModel saved to Google Drive: {remote_path}")
        else:
            print(f"\nWARNING: upload to {remote_path} failed "
                  f"(rclone exit code {result.returncode}): {result.stderr.strip()}")

# ================================================================
# Training 
# ================================================================

# NOTE(review): each trial is cut into many 50%-overlapping windows that all
# share the trial's label. The default window-level split (unchanged from
# before; presumably mirrors the paper's protocol — confirm) can put
# overlapping windows of one trial into both train and test, which inflates
# test accuracy. Set SPLIT_BY_TRIAL = True to keep all windows of a trial on
# the same side of the split.
SPLIT_BY_TRIAL = False
TEST_SIZE = 0.2
SPLIT_SEED = 42


def _split_windows(features, labels, n_trials_total, by_trial, test_size, seed):
    """Split windowed samples into train/test sets, stratified by label.

    Args:
        features: Array whose first axis indexes windows, ordered trial-major as
            produced by ``preprocess_deap`` (all windows of trial 0, then trial 1, ...).
        labels: One binary label per window.
        n_trials_total: Number of (subject, trial) pairs the windows came from.
        by_trial: If True, whole trials are assigned to train or test;
            otherwise individual windows are split.
        test_size: Fraction held out for testing.
        seed: Random seed, for reproducible splits.

    Returns:
        ``(X_train, X_test, y_train, y_test)``.

    Raises:
        ValueError: in trial mode, if the window count is not a multiple of
            ``n_trials_total``.
    """
    if not by_trial:
        return train_test_split(features, labels, test_size=test_size,
                                random_state=seed, stratify=labels)

    windows_per_trial, remainder = divmod(len(features), n_trials_total)
    if remainder:
        raise ValueError(f"{len(features)} windows cannot be divided evenly into {n_trials_total} trials")
    trial_of_window = np.arange(len(features)) // windows_per_trial
    # Every window inherits its trial's label, so a trial's first window carries it.
    trial_labels = labels[::windows_per_trial]
    _, test_trials = train_test_split(np.arange(n_trials_total), test_size=test_size,
                                      random_state=seed, stratify=trial_labels)
    is_test = np.isin(trial_of_window, test_trials)
    # Shuffle the training windows: Keras' validation_split takes the LAST
    # fraction of the arrays before shuffling, which would otherwise be the
    # final subjects only.
    rng = np.random.default_rng(seed)
    train_idx = rng.permutation(np.flatnonzero(~is_test))
    test_idx = np.flatnonzero(is_test)
    return features[train_idx], features[test_idx], labels[train_idx], labels[test_idx]


# eeg_data is (subjects, trials, ...), the array the windows were cut from.
N_TRIALS_TOTAL = eeg_data.shape[0] * eeg_data.shape[1]
X_train_val, X_test_val, y_train_val, y_test_val = _split_windows(
    stft_data, valence_labels, N_TRIALS_TOTAL, SPLIT_BY_TRIAL, TEST_SIZE, SPLIT_SEED
)
X_train_ar, X_test_ar, y_train_ar, y_test_ar = _split_windows(
    stft_data, arousal_labels, N_TRIALS_TOTAL, SPLIT_BY_TRIAL, TEST_SIZE, SPLIT_SEED
)


BUILDING CBSATT MODEL
Input shape: (16, 33, 25)
Learning rate: 0.001
Batch size: 128
Epochs: 30
Number of attention heads: 4
BiLSTM hidden units: 128

Model architecture:


In [7]:
print(
    "\n" + "=" * 70,
    "TRAINING VALENCE MODEL",
    "=" * 70,
    f"Training samples: {len(X_train_val)}",
    f"Test samples: {len(X_test_val)}",
    sep="\n",
)

valence_start = time.time()

# validation_split=0.2 holds out the last 20% of the training arrays
# (Keras selects it before shuffling); checkpoints are uploaded every 3 epochs.
history_val = valence_model.fit(
    x=X_train_val,
    y=y_train_val,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    verbose=1,
    callbacks=[RcloneSaveCallback('valence_model', save_every=3)],
    validation_split=0.2,
)

valence_train_time = time.time() - valence_start


TRAINING VALENCE MODEL
Training samples: 20480
Test samples: 5120
Epoch 1/30
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m22s[0m 84ms/step - accuracy: 0.7613 - loss: 0.5578 - val_accuracy: 0.7859 - val_loss: 0.5208
Epoch 2/30
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 72ms/step - accuracy: 0.7779 - loss: 0.5336 - val_accuracy: 0.7861 - val_loss: 0.5049
Epoch 3/30
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 68ms/step - accuracy: 0.7730 - loss: 0.5290




Model saved to Google Drive: gdrive:/deap_cbsatt_models/valence_model_epoch_3.h5
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 99ms/step - accuracy: 0.7730 - loss: 0.5290 - val_accuracy: 0.7859 - val_loss: 0.5050
Epoch 4/30
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 73ms/step - accuracy: 0.7673 - loss: 0.5250 - val_accuracy: 0.7866 - val_loss: 0.4933
Epoch 5/30
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 74ms/step - accuracy: 0.7731 - loss: 0.5112 - val_accuracy: 0.7874 - val_loss: 0.4825
Epoch 6/30
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 69ms/step - accuracy: 0.7808 - loss: 0.4952




Model saved to Google Drive: gdrive:/deap_cbsatt_models/valence_model_epoch_6.h5
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 100ms/step - accuracy: 0.7807 - loss: 0.4952 - val_accuracy: 0.7859 - val_loss: 0.4821
Epoch 7/30
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 73ms/step - accuracy: 0.7776 - loss: 0.4914 - val_accuracy: 0.7869 - val_loss: 0.4924
Epoch 8/30
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 73ms/step - accuracy: 0.7817 - loss: 0.4732 - val_accuracy: 0.7859 - val_loss: 0.4933
Epoch 9/30
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 68ms/step - accuracy: 0.7824 - loss: 0.4672




Model saved to Google Drive: gdrive:/deap_cbsatt_models/valence_model_epoch_9.h5
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 99ms/step - accuracy: 0.7824 - loss: 0.4672 - val_accuracy: 0.7874 - val_loss: 0.4831
Epoch 10/30
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 73ms/step - accuracy: 0.7848 - loss: 0.4578 - val_accuracy: 0.7866 - val_loss: 0.6129
Epoch 11/30
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 73ms/step - accuracy: 0.7979 - loss: 0.4369 - val_accuracy: 0.7231 - val_loss: 0.5391
Epoch 12/30
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 68ms/step - accuracy: 0.8266 - loss: 0.3959




Model saved to Google Drive: gdrive:/deap_cbsatt_models/valence_model_epoch_12.h5
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 100ms/step - accuracy: 0.8266 - loss: 0.3959 - val_accuracy: 0.7705 - val_loss: 0.5036
Epoch 13/30
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 74ms/step - accuracy: 0.8467 - loss: 0.3445 - val_accuracy: 0.7747 - val_loss: 0.5428
Epoch 14/30
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 73ms/step - accuracy: 0.8725 - loss: 0.2976 - val_accuracy: 0.7483 - val_loss: 0.5869
Epoch 15/30
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 68ms/step - accuracy: 0.9000 - loss: 0.2421




Model saved to Google Drive: gdrive:/deap_cbsatt_models/valence_model_epoch_15.h5
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 101ms/step - accuracy: 0.8999 - loss: 0.2422 - val_accuracy: 0.7905 - val_loss: 0.6537
Epoch 16/30
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 73ms/step - accuracy: 0.9067 - loss: 0.2257 - val_accuracy: 0.7830 - val_loss: 0.7654
Epoch 17/30
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 73ms/step - accuracy: 0.9272 - loss: 0.1764 - val_accuracy: 0.7603 - val_loss: 0.6543
Epoch 18/30
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 68ms/step - accuracy: 0.9505 - loss: 0.1345




Model saved to Google Drive: gdrive:/deap_cbsatt_models/valence_model_epoch_18.h5
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 97ms/step - accuracy: 0.9505 - loss: 0.1346 - val_accuracy: 0.7961 - val_loss: 1.1607
Epoch 19/30
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 74ms/step - accuracy: 0.9567 - loss: 0.1147 - val_accuracy: 0.7808 - val_loss: 0.9213
Epoch 20/30
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 73ms/step - accuracy: 0.9647 - loss: 0.0970 - val_accuracy: 0.7864 - val_loss: 0.9459
Epoch 21/30
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 68ms/step - accuracy: 0.9740 - loss: 0.0750




Model saved to Google Drive: gdrive:/deap_cbsatt_models/valence_model_epoch_21.h5
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 99ms/step - accuracy: 0.9740 - loss: 0.0751 - val_accuracy: 0.7834 - val_loss: 1.4529
Epoch 22/30
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 73ms/step - accuracy: 0.9727 - loss: 0.0746 - val_accuracy: 0.7529 - val_loss: 1.0823
Epoch 23/30
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 73ms/step - accuracy: 0.9695 - loss: 0.0765 - val_accuracy: 0.7986 - val_loss: 1.5665
Epoch 24/30
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 68ms/step - accuracy: 0.9722 - loss: 0.0732




Model saved to Google Drive: gdrive:/deap_cbsatt_models/valence_model_epoch_24.h5
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 102ms/step - accuracy: 0.9722 - loss: 0.0731 - val_accuracy: 0.7773 - val_loss: 1.1224
Epoch 25/30
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 73ms/step - accuracy: 0.9834 - loss: 0.0459 - val_accuracy: 0.7837 - val_loss: 1.4975
Epoch 26/30
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 73ms/step - accuracy: 0.9869 - loss: 0.0373 - val_accuracy: 0.7734 - val_loss: 1.3399
Epoch 27/30
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 68ms/step - accuracy: 0.9855 - loss: 0.0424




Model saved to Google Drive: gdrive:/deap_cbsatt_models/valence_model_epoch_27.h5
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 98ms/step - accuracy: 0.9855 - loss: 0.0424 - val_accuracy: 0.7874 - val_loss: 1.7720
Epoch 28/30
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 74ms/step - accuracy: 0.9890 - loss: 0.0308 - val_accuracy: 0.7803 - val_loss: 1.6077
Epoch 29/30
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 73ms/step - accuracy: 0.9878 - loss: 0.0355 - val_accuracy: 0.7505 - val_loss: 1.8099
Epoch 30/30
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 68ms/step - accuracy: 0.9887 - loss: 0.0305




Model saved to Google Drive: gdrive:/deap_cbsatt_models/valence_model_epoch_30.h5
[1m128/128[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 101ms/step - accuracy: 0.9887 - loss: 0.0305 - val_accuracy: 0.7832 - val_loss: 1.6285


In [None]:
print(
    "\n" + "=" * 70,
    "TRAINING AROUSAL MODEL",
    "=" * 70,
    f"Training samples: {len(X_train_ar)}",
    f"Test samples: {len(X_test_ar)}",
    sep="\n",
)

arousal_start = time.time()

# validation_split=0.2 holds out the last 20% of the training arrays
# (Keras selects it before shuffling); checkpoints are uploaded every 3 epochs.
history_ar = arousal_model.fit(
    x=X_train_ar,
    y=y_train_ar,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    verbose=1,
    callbacks=[RcloneSaveCallback('arousal_model', save_every=3)],
    validation_split=0.2,
)

arousal_train_time = time.time() - arousal_start

In [None]:

# ================================================================
# Evaluation (Paper: Fig. 6, Fig. 7, Fig. 8)
# ================================================================

GDRIVE_DIR = 'gdrive:/deap_cbsatt_models/'


def _upload_to_gdrive(local_path, remote_dir=GDRIVE_DIR):
    """Copy ``local_path`` into the rclone remote ``remote_dir``.

    Returns:
        True if ``rclone copy`` exited with status 0, otherwise False.
        Failures are printed rather than raised, so one failed upload does not
        stop the remaining saves.
    """
    try:
        result = subprocess.run(
            ['rclone', 'copy', local_path, remote_dir],
            capture_output=True, text=True,
        )
    except OSError as err:  # e.g. rclone binary not on PATH
        print(f"WARNING: could not run rclone for {local_path}: {err}")
        return False
    if result.returncode != 0:
        print(f"WARNING: rclone copy failed for {local_path} "
              f"(exit code {result.returncode}): {result.stderr.strip()}")
        return False
    return True


print("\n" + "=" * 70)
print("EVALUATION RESULTS")
print("=" * 70)

# Sigmoid probabilities -> hard 0/1 predictions at the 0.5 threshold.
y_pred_val = (valence_model.predict(X_test_val) > 0.5).astype(int).flatten()
y_pred_ar = (arousal_model.predict(X_test_ar) > 0.5).astype(int).flatten()

val_acc = accuracy_score(y_test_val, y_pred_val)
ar_acc = accuracy_score(y_test_ar, y_pred_ar)
# The label distributions printed during preprocessing are imbalanced, so plain
# accuracy can look good for a model that always predicts the majority class.
# Balanced accuracy (mean per-class recall) is 50% for such a constant predictor.
val_bal_acc = balanced_accuracy_score(y_test_val, y_pred_val)
ar_bal_acc = balanced_accuracy_score(y_test_ar, y_pred_ar)

print(f"\nValence Classification Accuracy: {val_acc*100:.2f}%")
print(f"Arousal Classification Accuracy: {ar_acc*100:.2f}%")
print(f"Average Accuracy: {(val_acc + ar_acc)/2 * 100:.2f}%")
print(f"\nValence Balanced Accuracy: {val_bal_acc*100:.2f}%")
print(f"Arousal Balanced Accuracy: {ar_bal_acc*100:.2f}%")

print(f"\nValence Training Time: {valence_train_time:.2f} seconds")
print(f"Arousal Training Time: {arousal_train_time:.2f} seconds")
print(f"Total Training Time: {valence_train_time + arousal_train_time:.2f} seconds")

# Confusion matrices (rows = actual, columns = predicted). labels=[0, 1] keeps
# them 2x2 even if a class appears in neither y_true nor y_pred.
CLASS_NAMES = ['Low', 'High']
cm_val = confusion_matrix(y_test_val, y_pred_val, labels=[0, 1])
cm_ar = confusion_matrix(y_test_ar, y_pred_ar, labels=[0, 1])

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

sns.heatmap(cm_val, annot=True, fmt='d', cmap='Blues', ax=axes[0],
            xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES)
axes[0].set_title('Valence Confusion Matrix')
axes[0].set_xlabel('Predicted')
axes[0].set_ylabel('Actual')

sns.heatmap(cm_ar, annot=True, fmt='d', cmap='Greens', ax=axes[1],
            xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES)
axes[1].set_title('Arousal Confusion Matrix')
axes[1].set_xlabel('Predicted')
axes[1].set_ylabel('Actual')

plt.tight_layout()
plt.savefig('/kaggle/working/confusion_matrices.png', dpi=300, bbox_inches='tight')
print("\nConfusion matrices saved to: /kaggle/working/confusion_matrices.png")

# Save final models locally, then upload them with the figure.
valence_model.save('/kaggle/working/valence_model_final.h5')
arousal_model.save('/kaggle/working/arousal_model_final.h5')

# Build the full list first so every upload is attempted even if one fails.
upload_results = [
    _upload_to_gdrive(path)
    for path in (
        '/kaggle/working/valence_model_final.h5',
        '/kaggle/working/arousal_model_final.h5',
        '/kaggle/working/confusion_matrices.png',
    )
]
if all(upload_results):
    print("\nFinal models and results saved to Google Drive!")
else:
    print("\nFinal models and results saved to /kaggle/working/, but at least one "
          "Google Drive upload FAILED (see warnings above).")

# Save performance metrics
with open('/kaggle/working/performance_metrics.txt', 'w', encoding='utf-8') as f:
    f.write("CBSAtt Performance Metrics\n")
    f.write("=" * 50 + "\n\n")
    f.write(f"Valence Accuracy: {val_acc*100:.2f}%\n")
    f.write(f"Arousal Accuracy: {ar_acc*100:.2f}%\n")
    f.write(f"Average Accuracy: {(val_acc + ar_acc)/2 * 100:.2f}%\n")
    f.write(f"Valence Balanced Accuracy: {val_bal_acc*100:.2f}%\n")
    f.write(f"Arousal Balanced Accuracy: {ar_bal_acc*100:.2f}%\n\n")
    f.write(f"Valence Training Time: {valence_train_time:.2f} seconds\n")
    f.write(f"Arousal Training Time: {arousal_train_time:.2f} seconds\n")
    f.write(f"Total Training Time: {valence_train_time + arousal_train_time:.2f} seconds\n\n")
    f.write("Hyperparameters:\n")
    f.write(f"  Learning Rate: {LEARNING_RATE}\n")
    f.write(f"  Batch Size: {BATCH_SIZE}\n")
    f.write(f"  Epochs: {EPOCHS}\n")
    f.write(f"  Attention Heads: {NUM_HEADS}\n")
    f.write(f"  LSTM Units: {LSTM_UNITS}\n")

if not _upload_to_gdrive('/kaggle/working/performance_metrics.txt'):
    print("Performance metrics kept locally at /kaggle/working/performance_metrics.txt")

print("\n" + "=" * 70)
print("TRAINING COMPLETE")
print("=" * 70)