# In [None]:
"""
EEG Person Identification - CNN + RNN Model Training
Hybrid architecture combining spatial-frequency and temporal features
Author: [Your Name]
Date: 2025

Model Architecture:
1. CNN Branch: Extracts spatial-frequency features from spectrograms
2. RNN Branch: Captures temporal dynamics from raw EEG
3. Fusion Layer: Combines both feature representations
4. Classification: 109-class person identification
"""

#%% Import Required Libraries
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import h5py
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, TensorBoard
from tensorflow.keras.utils import to_categorical
import warnings
warnings.filterwarnings('ignore')

# Reproducibility: pin the global TensorFlow and NumPy RNG seeds.
# (Python's built-in `random` module is not seeded here.)
tf.random.set_seed(42)
np.random.seed(42)

print(
    "Libraries imported successfully!",
    f"TensorFlow version: {tf.__version__}",
    f"GPU Available: {tf.config.list_physical_devices('GPU')}",
    sep="\n",
)

#%% Configuration

class Config:
    """Central store of file paths and model/training hyperparameters.

    Note: the output directories are created while this class body executes,
    i.e. when the module is imported, not when ``Config()`` is instantiated.
    """

    # ---- Filesystem locations ----
    PROCESSED_DATA_PATH = './data/processed/eeg_processed_data.h5'
    MODEL_DIR = './models/'
    RESULTS_DIR = './results/'

    # ---- Model dimensions ----
    N_CLASSES = 109              # one output class per subject
    N_CHANNELS = 64              # EEG channels
    SEQUENCE_LENGTH = 3 * 160    # 3-second windows at 160 Hz -> 480 samples

    # ---- Training ----
    BATCH_SIZE = 32
    EPOCHS = 100
    LEARNING_RATE = 1e-3
    VALIDATION_SPLIT = 0.15      # fraction of the full dataset
    TEST_SPLIT = 0.15            # fraction of the full dataset

    # Make sure the output folders exist before anything is written to them.
    os.makedirs(MODEL_DIR, exist_ok=True)
    os.makedirs(RESULTS_DIR, exist_ok=True)


config = Config()
print("\nConfiguration loaded!")

#%% Load Processed Data

def load_processed_data(filepath):
    """
    Load preprocessed data from HDF5 file

    Parameters:
    -----------
    filepath : str
        Path to the HDF5 file. It must contain the datasets 'X_epochs',
        'X_spectrograms' and 'y_subjects'; any file-level attributes are
        printed as metadata.

    Returns:
    --------
    X_epochs : ndarray, shape (n_samples, n_channels, n_timesteps)
        Raw filtered EEG epochs
    X_spectrograms : ndarray, shape (n_samples, n_channels, n_freq, n_time)
        Spectrogram representations
    y : ndarray
        Subject labels

    Raises:
    -------
    KeyError
        If one of the required datasets is missing from the file.
    ValueError
        If the three arrays do not have the same number of samples.
    """
    print("\n" + "="*60)
    print("LOADING PROCESSED DATA")
    print("="*60)

    with h5py.File(filepath, 'r') as hf:
        X_epochs = hf['X_epochs'][:]
        X_spectrograms = hf['X_spectrograms'][:]
        y_subjects = hf['y_subjects'][:]
        # 'y_tasks' is also stored in the file but is not needed for person
        # identification, so it is deliberately not read into memory.

        # Print metadata
        print("\nDataset Metadata:")
        for key, value in hf.attrs.items():
            print(f"  {key}: {value}")

    # Every sample must have an epoch, a spectrogram and a label; a mismatch here
    # would otherwise surface much later as a confusing error in train_test_split.
    n_samples = len(y_subjects)
    if len(X_epochs) != n_samples or len(X_spectrograms) != n_samples:
        raise ValueError(
            "Sample count mismatch in processed data: "
            f"X_epochs={len(X_epochs)}, X_spectrograms={len(X_spectrograms)}, "
            f"y_subjects={n_samples}"
        )

    print("\nData loaded successfully!")
    print(f"  Epochs shape: {X_epochs.shape}")
    print(f"  Spectrograms shape: {X_spectrograms.shape}")
    print(f"  Subjects shape: {y_subjects.shape}")
    print(f"  Unique subjects: {len(np.unique(y_subjects))}")
    print(f"  Total samples: {n_samples}")

    return X_epochs, X_spectrograms, y_subjects

#%% Prepare Data for Training

def _can_stratify(labels, test_fraction):
    """Return True if scikit-learn can stratify ``labels`` with ``test_fraction`` held out.

    Mirrors scikit-learn's preconditions for a stratified split: every class
    needs at least two members, and both resulting partitions must be large
    enough to contain at least one sample of every class.
    """
    labels = np.asarray(labels)
    if labels.size == 0:
        return False
    _, counts = np.unique(labels, return_counts=True)
    n_classes = len(counts)
    n_test = int(np.ceil(test_fraction * len(labels)))
    n_train = len(labels) - n_test
    return bool(counts.min() >= 2 and n_test >= n_classes and n_train >= n_classes)


def prepare_data(X_epochs, X_spectrograms, y_subjects, config):
    """
    Encode labels, split into train/validation/test sets and reshape for the model.

    Parameters:
    -----------
    X_epochs : ndarray, shape (n_samples, n_channels, n_timesteps)
        Raw filtered EEG epochs.
    X_spectrograms : ndarray, shape (n_samples, n_channels, n_freq, n_time)
        Spectrogram representations.
    y_subjects : ndarray, shape (n_samples,)
        Subject labels.
    config : Config
        Uses ``N_CLASSES`` plus ``VALIDATION_SPLIT`` and ``TEST_SPLIT``
        (fractions of the whole dataset; the remainder is used for training).

    Returns:
    --------
    tuple
        (X_epoch_train, X_spec_train, y_train,
         X_epoch_val, X_spec_val, y_val,
         X_epoch_test, X_spec_test, y_test,
         label_encoder)
        X_epoch_* are (n, timesteps, channels), X_spec_* are
        (n, freq, time, channels), and y_* are one-hot of width N_CLASSES.

    Raises:
    -------
    ValueError
        If the split fractions are not positive or sum to 1 or more, or if the
        data contains more distinct subjects than ``config.N_CLASSES``.
    """
    print("\n" + "="*60)
    print("PREPARING DATA FOR TRAINING")
    print("="*60)

    # Validate split fractions up front so a bad config fails fast.
    val_fraction = getattr(config, 'VALIDATION_SPLIT', 0.15)
    test_fraction = getattr(config, 'TEST_SPLIT', 0.15)
    holdout_fraction = val_fraction + test_fraction
    if val_fraction <= 0 or test_fraction <= 0 or holdout_fraction >= 1:
        raise ValueError(
            "VALIDATION_SPLIT and TEST_SPLIT must be positive and sum to less than 1; "
            f"got {val_fraction} and {test_fraction}"
        )
    # Share of the held-out pool that becomes the test set (0.5 for 15% / 15%).
    test_share_of_holdout = test_fraction / holdout_fraction

    # Encode labels (subjects 1-109 to 0-108)
    label_encoder = LabelEncoder()
    y_encoded = label_encoder.fit_transform(y_subjects)
    n_subjects = len(label_encoder.classes_)
    if n_subjects > config.N_CLASSES:
        raise ValueError(
            f"Found {n_subjects} distinct subjects but config.N_CLASSES is {config.N_CLASSES}"
        )
    y_categorical = to_categorical(y_encoded, num_classes=config.N_CLASSES)

    print(f"\nLabel encoding:")
    print(f"  Original range: {y_subjects.min()} - {y_subjects.max()}")
    print(f"  Encoded range: {y_encoded.min()} - {y_encoded.max()}")
    print(f"  One-hot shape: {y_categorical.shape}")

    # First split: train vs. held-out pool, stratified by subject.
    X_epoch_train, X_epoch_temp, X_spec_train, X_spec_temp, y_train, y_temp = train_test_split(
        X_epochs, X_spectrograms, y_categorical,
        test_size=holdout_fraction, random_state=42, stratify=y_encoded
    )

    # Second split: validation vs. test. Stratify here as well so every subject
    # is represented proportionally in both sets; fall back to a plain shuffle
    # only when the held-out pool is too small for scikit-learn to stratify.
    y_temp_labels = np.argmax(y_temp, axis=1)
    if _can_stratify(y_temp_labels, test_share_of_holdout):
        stratify_temp = y_temp_labels
    else:
        stratify_temp = None
        print("  Warning: held-out pool too small to stratify val/test split; "
              "using an unstratified split")
    X_epoch_val, X_epoch_test, X_spec_val, X_spec_test, y_val, y_test = train_test_split(
        X_epoch_temp, X_spec_temp, y_temp,
        test_size=test_share_of_holdout, random_state=42, stratify=stratify_temp
    )

    print(f"\nData splits:")
    print(f"  Training set: {len(X_epoch_train)} samples ({len(X_epoch_train)/len(X_epochs)*100:.1f}%)")
    print(f"  Validation set: {len(X_epoch_val)} samples ({len(X_epoch_val)/len(X_epochs)*100:.1f}%)")
    print(f"  Test set: {len(X_epoch_test)} samples ({len(X_epoch_test)/len(X_epochs)*100:.1f}%)")

    # Reshape data for model inputs
    # RNN input: (batch, timesteps, channels)
    X_epoch_train = np.transpose(X_epoch_train, (0, 2, 1))
    X_epoch_val = np.transpose(X_epoch_val, (0, 2, 1))
    X_epoch_test = np.transpose(X_epoch_test, (0, 2, 1))

    # CNN input: (batch, height, width, channels) - treat channels as the "image"
    # Current: (batch, channels, freq, time) -> (batch, freq, time, channels)
    X_spec_train = np.transpose(X_spec_train, (0, 2, 3, 1))
    X_spec_val = np.transpose(X_spec_val, (0, 2, 3, 1))
    X_spec_test = np.transpose(X_spec_test, (0, 2, 3, 1))

    print(f"\nReshaped data:")
    print(f"  RNN input (train): {X_epoch_train.shape}")
    print(f"  CNN input (train): {X_spec_train.shape}")

    return (X_epoch_train, X_spec_train, y_train,
            X_epoch_val, X_spec_val, y_val,
            X_epoch_test, X_spec_test, y_test,
            label_encoder)

#%% Build CNN+RNN Hybrid Model

def build_hybrid_model(config, epoch_shape, spec_shape):
    """
    Build and compile the hybrid CNN+RNN model for person identification.

    Architecture:
    1. RNN Branch: two stacked bidirectional LSTMs over the raw EEG sequence
    2. CNN Branch: three Conv2D blocks over the spectrogram "image"
    3. Fusion: concatenation of the 64-d feature vector from each branch
    4. Classification: Dense layers with a softmax over ``config.N_CLASSES``

    Parameters:
    -----------
    config : Config
        Supplies ``N_CLASSES`` and ``LEARNING_RATE``.
    epoch_shape : tuple
        Shape of RNN input (timesteps, channels)
    spec_shape : tuple
        Shape of CNN input (freq, time, channels). The branch applies three
        2x2 max-pooling layers with Keras's default 'valid' padding, so both
        freq and time must be at least 8.

    Returns:
    --------
    model : keras.Model
        Hybrid model compiled with Adam, categorical cross-entropy, and
        accuracy / top-5 accuracy metrics.
    """
    print("\n" + "="*60)
    print("BUILDING HYBRID CNN+RNN MODEL")
    print("="*60)

    # ==================== RNN BRANCH ====================
    # Input: (timesteps, channels)
    rnn_input = layers.Input(shape=epoch_shape, name='rnn_input')

    # Bidirectional LSTM layers; the second returns only its final state.
    x_rnn = layers.Bidirectional(
        layers.LSTM(128, return_sequences=True, dropout=0.3)
    )(rnn_input)
    x_rnn = layers.Bidirectional(
        layers.LSTM(64, return_sequences=False, dropout=0.3)
    )(x_rnn)

    # Dense layers for RNN branch
    x_rnn = layers.Dense(128, activation='relu')(x_rnn)
    x_rnn = layers.BatchNormalization()(x_rnn)
    x_rnn = layers.Dropout(0.4)(x_rnn)
    rnn_output = layers.Dense(64, activation='relu', name='rnn_features')(x_rnn)

    print("\n✓ RNN Branch built:")
    print("  - Bidirectional LSTM (128 units)")
    print("  - Bidirectional LSTM (64 units)")
    print("  - Dense layers with dropout")

    # ==================== CNN BRANCH ====================
    # Input: (freq, time, channels)
    cnn_input = layers.Input(shape=spec_shape, name='cnn_input')

    # Convolutional blocks (each pooling halves freq and time)
    x_cnn = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(cnn_input)
    x_cnn = layers.BatchNormalization()(x_cnn)
    x_cnn = layers.MaxPooling2D((2, 2))(x_cnn)
    x_cnn = layers.Dropout(0.3)(x_cnn)

    x_cnn = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(x_cnn)
    x_cnn = layers.BatchNormalization()(x_cnn)
    x_cnn = layers.MaxPooling2D((2, 2))(x_cnn)
    x_cnn = layers.Dropout(0.3)(x_cnn)

    x_cnn = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x_cnn)
    x_cnn = layers.BatchNormalization()(x_cnn)
    x_cnn = layers.MaxPooling2D((2, 2))(x_cnn)
    x_cnn = layers.Dropout(0.4)(x_cnn)

    # Global average pooling
    x_cnn = layers.GlobalAveragePooling2D()(x_cnn)

    # Dense layers for CNN branch
    x_cnn = layers.Dense(128, activation='relu')(x_cnn)
    x_cnn = layers.BatchNormalization()(x_cnn)
    x_cnn = layers.Dropout(0.4)(x_cnn)
    cnn_output = layers.Dense(64, activation='relu', name='cnn_features')(x_cnn)

    print("\n✓ CNN Branch built:")
    print("  - 3 Convolutional blocks (32, 64, 128 filters)")
    print("  - Batch normalization and dropout")
    print("  - Global average pooling")

    # ==================== FUSION & CLASSIFICATION ====================
    # Concatenate features from both branches
    merged = layers.concatenate([rnn_output, cnn_output], name='fusion')

    # Classification head
    x = layers.Dense(256, activation='relu')(merged)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.5)(x)

    x = layers.Dense(128, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.5)(x)

    # Output layer: one softmax unit per subject
    output = layers.Dense(config.N_CLASSES, activation='softmax', name='output')(x)

    # Report the actual sizes rather than hard-coded numbers so the log stays
    # correct if N_CLASSES or the branch widths change.
    print("\n✓ Fusion & Classification layers built:")
    print(f"  - Concatenated feature dimension: {merged.shape[-1]}")
    print("  - Dense layers (256, 128)")
    print(f"  - Output layer: {config.N_CLASSES} classes (subjects)")

    # Create model
    model = Model(inputs=[rnn_input, cnn_input], outputs=output, name='CNN_RNN_Hybrid')

    # Compile model
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=config.LEARNING_RATE),
        loss='categorical_crossentropy',
        metrics=['accuracy', tf.keras.metrics.TopKCategoricalAccuracy(k=5, name='top5_accuracy')]
    )

    print("\n✓ Model compiled with:")
    print(f"  - Optimizer: Adam (lr={config.LEARNING_RATE})")
    print(f"  - Loss: Categorical Crossentropy")
    print(f"  - Metrics: Accuracy, Top-5 Accuracy")

    return model

#%% Training Callbacks

def create_callbacks(config):
    """
    Assemble the Keras callbacks used during ``model.fit``.

    Parameters:
    -----------
    config : Config
        Supplies ``MODEL_DIR`` (checkpoint location) and ``RESULTS_DIR``
        (TensorBoard log location).

    Returns:
    --------
    callbacks : list
        [ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, TensorBoard]
    """
    checkpoint_path = os.path.join(config.MODEL_DIR, 'best_model.h5')
    tensorboard_dir = os.path.join(config.RESULTS_DIR, 'logs')

    # NOTE(review): the checkpoint keeps the epoch with the highest val_accuracy,
    # while EarlyStopping(restore_best_weights=True) tracks the lowest val_loss.
    # The saved file and the in-memory model may therefore come from different
    # epochs. Confirm this is intended.
    checkpoint = ModelCheckpoint(
        filepath=checkpoint_path,
        monitor='val_accuracy',
        save_best_only=True,
        mode='max',
        verbose=1,
    )

    # Stop after 15 epochs without val_loss improvement.
    early_stopping = EarlyStopping(
        monitor='val_loss',
        patience=15,
        restore_best_weights=True,
        verbose=1,
    )

    # Halve the learning rate after 7 stagnant epochs, never going below 1e-6.
    lr_on_plateau = ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=7,
        min_lr=1e-6,
        verbose=1,
    )

    # histogram_freq=1 writes weight histograms every epoch.
    tensorboard = TensorBoard(log_dir=tensorboard_dir, histogram_freq=1)

    return [checkpoint, early_stopping, lr_on_plateau, tensorboard]

#%% Train Model

def train_model(model, X_epoch_train, X_spec_train, y_train,
                X_epoch_val, X_spec_val, y_val, config):
    """
    Fit the two-input hybrid model and return its Keras History.

    The RNN input (epochs) and the CNN input (spectrograms) are passed as a
    two-element list, in the same order as the model's inputs. Callbacks come
    from ``create_callbacks(config)``.

    Returns:
    --------
    history : History object
        Training history
    """
    rule = "=" * 60
    print("\n" + rule)
    print("TRAINING MODEL")
    print(rule)

    callbacks = create_callbacks(config)

    run_settings = (
        ("Batch size", config.BATCH_SIZE),
        ("Epochs", config.EPOCHS),
        ("Training samples", len(X_epoch_train)),
        ("Validation samples", len(X_epoch_val)),
    )
    print("\nStarting training:")
    for label, value in run_settings:
        print(f"  {label}: {value}")

    train_inputs = [X_epoch_train, X_spec_train]
    val_inputs = [X_epoch_val, X_spec_val]
    history = model.fit(
        train_inputs,
        y_train,
        batch_size=config.BATCH_SIZE,
        epochs=config.EPOCHS,
        validation_data=(val_inputs, y_val),
        callbacks=callbacks,
        verbose=1,
    )

    print("\n✓ Training complete!")
    return history

#%% Plot Training History

def _plot_metric_pair(ax, logs, key, train_label, val_label, ylabel, title):
    """Draw the train/validation curves for ``key`` on ``ax``, or a placeholder if absent."""
    plotted = False
    for log_key, label in ((key, train_label), (f'val_{key}', val_label)):
        if log_key in logs:
            ax.plot(logs[log_key], label=label, linewidth=2)
            plotted = True
    if not plotted:
        ax.text(0.5, 0.5, f'{title} not logged',
                ha='center', va='center', transform=ax.transAxes)
        ax.axis('off')
        return
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)


def plot_training_history(history, config):
    """
    Plot training and validation metrics and save the figure.

    Draws accuracy, loss, top-5 accuracy and (if logged) the learning rate in a
    2x2 grid, and saves it as ``training_history.png`` in ``config.RESULTS_DIR``.
    Metrics missing from the history get a "not logged" placeholder instead of
    raising KeyError.

    Parameters:
    -----------
    history : History
        Object whose ``history`` attribute maps metric names to per-epoch lists.
    config : Config
        Supplies ``RESULTS_DIR``.
    """
    print("\n" + "="*60)
    print("PLOTTING TRAINING HISTORY")
    print("="*60)

    logs = history.history
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Plots 1-3: accuracy, loss, top-5 accuracy (train vs. validation)
    _plot_metric_pair(axes[0, 0], logs, 'accuracy',
                      'Train Accuracy', 'Val Accuracy', 'Accuracy', 'Model Accuracy')
    _plot_metric_pair(axes[0, 1], logs, 'loss',
                      'Train Loss', 'Val Loss', 'Loss', 'Model Loss')
    _plot_metric_pair(axes[1, 0], logs, 'top5_accuracy',
                      'Train Top-5', 'Val Top-5', 'Top-5 Accuracy', 'Top-5 Accuracy')

    # Plot 4: Learning rate. ReduceLROnPlateau logs it as 'lr' in older tf.keras
    # and as 'learning_rate' in Keras 3, so accept either key.
    ax4 = axes[1, 1]
    lr_key = next((k for k in ('lr', 'learning_rate') if k in logs), None)
    if lr_key is not None:
        ax4.plot(logs[lr_key], linewidth=2, color='coral')
        ax4.set_xlabel('Epoch')
        ax4.set_ylabel('Learning Rate')
        ax4.set_title('Learning Rate Schedule')
        ax4.set_yscale('log')
        ax4.grid(True, alpha=0.3)
    else:
        ax4.text(0.5, 0.5, 'Learning rate not logged',
                 ha='center', va='center', transform=ax4.transAxes)
        ax4.axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(config.RESULTS_DIR, 'training_history.png'),
                dpi=300, bbox_inches='tight')
    print("\n✓ Training history plot saved!")
    plt.show()
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)

#%% Save Training Results

def save_training_results(history, model, config):
    """
    Save training history and model summary

    Writes ``training_history.csv`` (one row per epoch, one column per logged
    metric) and ``model_summary.txt`` into ``config.RESULTS_DIR``.

    Parameters:
    -----------
    history : History
        Object whose ``history`` attribute maps metric names to per-epoch lists.
    model : keras.Model
        Model whose ``summary`` is written to disk.
    config : Config
        Supplies ``RESULTS_DIR``.
    """
    # Save history to CSV
    history_df = pd.DataFrame(history.history)
    history_df.to_csv(os.path.join(config.RESULTS_DIR, 'training_history.csv'), index=False)

    # Save model summary. Write UTF-8 explicitly: Keras summaries may contain
    # non-ASCII (box-drawing) characters that the platform default encoding
    # cannot represent.
    summary_path = os.path.join(config.RESULTS_DIR, 'model_summary.txt')
    with open(summary_path, 'w', encoding='utf-8') as f:
        def _write_summary_line(text, line_break=True, **_unused):
            # Some Keras versions may call print_fn with an extra ``line_break``
            # keyword; accept it (and any other extras) instead of raising TypeError.
            f.write(text + ('\n' if line_break else ''))

        model.summary(print_fn=_write_summary_line)

    print("\n✓ Training results saved!")

#%% Main Execution

if __name__ == "__main__":
    import pickle

    print("\n" + "=" * 60)
    print("CNN+RNN HYBRID MODEL - TRAINING PIPELINE")
    print("=" * 60)

    # --- 1. Data: load from HDF5, then encode / split / reshape ------------
    X_epochs, X_spectrograms, y_subjects = load_processed_data(config.PROCESSED_DATA_PATH)
    (X_epoch_train, X_spec_train, y_train,
     X_epoch_val, X_spec_val, y_val,
     X_epoch_test, X_spec_test, y_test,
     label_encoder) = prepare_data(X_epochs, X_spectrograms, y_subjects, config)

    # --- 2. Model: per-sample input shapes (batch axis dropped) -------------
    # epoch_shape is (timesteps, channels); spec_shape is (freq, time, channels).
    epoch_shape, spec_shape = X_epoch_train.shape[1:], X_spec_train.shape[1:]
    model = build_hybrid_model(config, epoch_shape, spec_shape)

    print("\n" + "=" * 60)
    print("MODEL SUMMARY")
    print("=" * 60)
    model.summary()

    total_params = model.count_params()
    print(f"\nTotal Parameters: {total_params:,}")

    # --- 3. Training and training artefacts ---------------------------------
    history = train_model(model, X_epoch_train, X_spec_train, y_train,
                          X_epoch_val, X_spec_val, y_val, config)
    plot_training_history(history, config)
    save_training_results(history, model, config)

    # --- 4. Hand-off for the evaluation step --------------------------------
    test_data_path = os.path.join(config.RESULTS_DIR, 'test_data.npz')
    np.savez(test_data_path,
             X_epoch_test=X_epoch_test,
             X_spec_test=X_spec_test,
             y_test=y_test)

    encoder_path = os.path.join(config.RESULTS_DIR, 'label_encoder.pkl')
    with open(encoder_path, 'wb') as f:
        pickle.dump(label_encoder, f)

    best_model_path = os.path.join(config.MODEL_DIR, 'best_model.h5')
    print("\n" + "=" * 60)
    print("TRAINING COMPLETE!")
    print("=" * 60)
    print(f"\nBest model saved to: {best_model_path}")
    print(f"Results saved to: {config.RESULTS_DIR}")
    print("\nNext step: Run 04_evaluation_visualization.ipynb")

#%% Training Summary

"""
MODEL TRAINING SUMMARY
======================

Hybrid CNN+RNN Architecture:
- RNN Branch: Bidirectional LSTM (128→64 units) for temporal features
- CNN Branch: 3 Conv2D blocks (32, 64, 128 filters) for spatial-frequency features
- Fusion: Concatenated features (128 dimensions)
- Classification: Dense layers (256, 128) → 109-class softmax

Training Configuration:
- Optimizer: Adam (lr=0.001 with ReduceLROnPlateau)
- Loss: Categorical Crossentropy
- Metrics: Accuracy, Top-5 Accuracy
- Batch size: 32
- Epochs: 100 (with early stopping)
- Data split: 70% train, 15% val, 15% test

Regularization:
- Dropout layers (0.3-0.5)
- Batch normalization
- Early stopping (patience=15)
- Learning rate reduction (patience=7)

Output Files:
- best_model.h5: Full model (architecture + weights) from the epoch with the highest validation accuracy
- training_history.csv: Training metrics
- training_history.png: Training plots
- model_summary.txt: Model architecture
- test_data.npz: Test set for evaluation
- label_encoder.pkl: Label encoder for decoding

The model is now ready for evaluation!
"""