# Attention-Enhanced BiLSTM-CNN Model
This notebook implements an advanced hybrid architecture combining:
- **Bidirectional LSTM** for temporal context
- **1D CNN** for local feature extraction
- **Attention Mechanism** for important feature focus
- **Skip Connections** for gradient flow

## Model Architecture Benefits:
- ✅ Captures both past and future context (BiLSTM)
- ✅ Learns hierarchical features (CNN)
- ✅ Focuses on discriminative patterns (Attention)
- ✅ Better gradient propagation (Skip connections)
- ✅ Reduces overfitting (Dropout, BatchNorm)

In [None]:
import tensorflow as tf
from tensorflow import keras
import tensorflow.keras.layers as tfl
from tensorflow.keras.layers import *
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
import numpy as np
import matplotlib.pyplot as plt
from time import time

# Report the runtime environment: TensorFlow version and the GPUs it can see.
print(
    f"TensorFlow version: {tf.__version__}",
    f"GPU Available: {tf.config.list_physical_devices('GPU')}",
    sep="\n",
)

## 1. Attention Layer Implementation

In [None]:
class AttentionLayer(Layer):
    """
    Custom additive attention layer for time series.

    For every time step the layer computes a scalar score
    ``tanh(x . W + b) . u``, normalises the scores with a softmax over
    axis 1 and returns the score-weighted sum of the inputs over that axis.

    The axis handling in ``call`` (softmax and sum over axis 1, weights
    sized by the last axis) implies an input laid out as
    (batch, timesteps, features).

    Calling the layer returns a 2-tuple:
        context : tensor of shape (batch, features)
            Attention-weighted sum of the inputs over the time axis.
        attention_weights : tensor of shape (batch, timesteps, 1)
            Softmax-normalised weight assigned to each time step.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        """Create the trainable projection, bias and context vector."""
        feature_dim = input_shape[-1]
        self.W = self.add_weight(name='attention_weight',
                                 shape=(feature_dim, feature_dim),
                                 initializer='glorot_uniform',
                                 trainable=True)
        self.b = self.add_weight(name='attention_bias',
                                 shape=(feature_dim,),
                                 initializer='zeros',
                                 trainable=True)
        self.u = self.add_weight(name='attention_context',
                                 shape=(feature_dim,),
                                 initializer='glorot_uniform',
                                 trainable=True)
        super().build(input_shape)

    def call(self, x):
        """Return ``(context, attention_weights)`` for the input sequence."""
        # Project every time step, then score it against the context vector.
        uit = tf.tanh(tf.tensordot(x, self.W, axes=1) + self.b)
        ait = tf.tensordot(uit, self.u, axes=1)

        # Normalise the scores across the time axis.
        attention_weights = tf.nn.softmax(ait, axis=1)
        # Trailing axis of size 1 so the weights broadcast over features.
        attention_weights = tf.expand_dims(attention_weights, axis=-1)

        # Weighted sum over time collapses the sequence to one vector.
        weighted_input = x * attention_weights

        return tf.reduce_sum(weighted_input, axis=1), attention_weights

    def compute_output_shape(self, input_shape):
        """Return the shapes of both outputs produced by ``call``.

        ``call`` returns two tensors, so two shapes are reported here.
        Reporting only the context shape would make Keras versions that
        derive symbolic outputs from this method see a single output.
        """
        batch, steps, features = input_shape[0], input_shape[1], input_shape[-1]
        return (batch, features), (batch, steps, 1)

    def get_config(self):
        """Return the base layer config (no extra constructor arguments)."""
        return super().get_config()

# Cell status message (mis-encoded check mark repaired).
print("✅ Attention layer defined")

## 2. Attention-Enhanced BiLSTM-CNN Model

In [None]:
def build_attention_bilstm_cnn(input_shape, use_attention=True):
    """
    Build Attention-Enhanced BiLSTM-CNN model for sleep disorder classification

    Architecture:
    1. Initial CNN block for local feature extraction
    2. Residual CNN block (skip connection) with batch normalization
    3. Further CNN blocks with max-pooling and dropout
    4. Two stacked Bidirectional LSTMs for temporal dependencies
    5. Attention mechanism (or global average pooling) to collapse time
    6. L2-regularised dense layers and a single sigmoid output

    Parameters:
    -----------
    input_shape : tuple
        Shape of input (timesteps, features)
    use_attention : bool
        If True the sequence is pooled with AttentionLayer, otherwise with
        GlobalAveragePooling1D.

    Returns:
    --------
    model : keras.Model
        The assembled, *uncompiled* model; call ``model.compile(...)``
        before training.
    """

    def conv_bn(tensor, filters, kernel_size, tag, activate=True):
        # Conv1D -> BatchNorm (-> ReLU); layer names are derived from `tag`.
        tensor = Conv1D(filters=filters, kernel_size=kernel_size, strides=1,
                        padding='same', name=f'conv{tag}')(tensor)
        tensor = BatchNormalization(name=f'bn{tag}')(tensor)
        if activate:
            tensor = Activation('relu', name=f'relu{tag}')(tensor)
        return tensor

    input_signal = Input(shape=input_shape, name='input')

    # ====== Block 1: Initial Feature Extraction ======
    x = conv_bn(input_signal, 32, 7, '1_1')
    x = conv_bn(x, 32, 7, '1_2')

    # Saved for the residual addition at the end of block 2.
    skip1 = x

    # ====== Block 2: Deeper Feature Extraction with Skip ======
    x = conv_bn(x, 32, 9, '2_1')
    # No ReLU here: the activation is applied after the residual addition.
    x = conv_bn(x, 32, 9, '2_2', activate=False)
    x = Add(name='skip_add_1')([x, skip1])
    x = Activation('relu', name='relu2_3')(x)

    # ====== Block 3: Downsampling ======
    x = conv_bn(x, 64, 9, '3_1')
    x = MaxPooling1D(pool_size=2, padding='same', name='pool1')(x)
    x = Dropout(0.3, name='dropout1')(x)

    # ====== Block 4: More CNN layers ======
    x = conv_bn(x, 64, 7, '4_1')
    x = conv_bn(x, 32, 5, '4_2')
    x = MaxPooling1D(pool_size=2, padding='same', name='pool2')(x)

    # ====== Block 5: Bidirectional LSTM ======
    x = Bidirectional(LSTM(64, return_sequences=True, dropout=0.2, recurrent_dropout=0.2),
                      name='bilstm1')(x)
    x = BatchNormalization(name='bn_lstm1')(x)

    x = Bidirectional(LSTM(32, return_sequences=True, dropout=0.2, recurrent_dropout=0.2),
                      name='bilstm2')(x)
    x = BatchNormalization(name='bn_lstm2')(x)

    # ====== Block 6: Attention Mechanism ======
    if use_attention:
        # The layer also returns the per-step weights; they are not used here.
        x, _ = AttentionLayer(name='attention')(x)
    else:
        x = GlobalAveragePooling1D(name='global_avg_pool')(x)

    # ====== Block 7: Dense Classification Layers ======
    x = Dense(64, activation='relu', kernel_regularizer=tf.keras.regularizers.L2(0.01),
              name='dense1')(x)
    x = Dropout(0.4, name='dropout2')(x)

    x = Dense(32, activation='relu', kernel_regularizer=tf.keras.regularizers.L2(0.01),
              name='dense2')(x)
    x = Dropout(0.3, name='dropout3')(x)

    # Single sigmoid unit: binary classification output.
    output = Dense(1, activation='sigmoid', name='output')(x)

    return keras.Model(inputs=input_signal, outputs=output, name='AttentionBiLSTM_CNN')

# Cell status message (mis-encoded check mark repaired).
print("✅ Model architecture defined")

## 3. Compile Model with Advanced Optimizer

In [None]:
def compile_attention_model(model, learning_rate=0.001):
    """
    Attach an Adam optimizer, binary cross-entropy loss and metrics to a model

    Parameters:
    -----------
    model : keras.Model
        Model to compile (compiled in place)
    learning_rate : float
        Initial learning rate handed to Adam

    Returns:
    --------
    model : keras.Model
        The same model object, now compiled
    """
    adam = keras.optimizers.Adam(learning_rate=learning_rate,
                                 beta_1=0.9, beta_2=0.999, epsilon=1e-07)

    # Accuracy plus precision / recall / AUC, each under a fixed name.
    tracked_metrics = [
        'accuracy',
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.Recall(name='recall'),
        tf.keras.metrics.AUC(name='auc'),
    ]

    model.compile(optimizer=adam, loss='binary_crossentropy', metrics=tracked_metrics)
    return model

# Cell status message (mis-encoded check mark repaired).
print("✅ Compilation function defined")

## 4. Training Function with Callbacks

In [None]:
def train_attention_model(model, x_train, y_train, x_val=None, y_val=None, 
                         epochs=150, batch_size=64, save_path=None):
    """
    Train a compiled model with early stopping, LR scheduling and checkpointing

    Parameters:
    -----------
    model : keras.Model
        Compiled model to train
    x_train, y_train : arrays
        Training data
    x_val, y_val : arrays or None
        Validation data. If either is None, the last 20% of the training
        data is held out via ``validation_split=0.2`` instead.
    epochs : int
        Maximum number of epochs
    batch_size : int
        Batch size
    save_path : str or None
        If given, the best model (by ``val_accuracy``) is saved to this path

    Returns:
    --------
    history : History object
        Training history
    training_time : float
        Total training time in seconds
    """

    callbacks_list = [
        # Stop when val_loss has not improved for 30 epochs; keep best weights.
        EarlyStopping(
            monitor='val_loss',
            patience=30,
            verbose=1,
            restore_best_weights=True,
            mode='min'
        ),
        # Halve the learning rate after 10 stagnant epochs, down to 1e-6.
        ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=10,
            min_lr=1e-6,
            verbose=1,
            mode='min'
        ),
    ]

    if save_path:
        callbacks_list.append(ModelCheckpoint(
            save_path,
            monitor='val_accuracy',
            verbose=1,
            save_best_only=True,
            mode='max'
        ))

    # Both validation modes share every other fit() argument.
    fit_kwargs = dict(
        epochs=epochs,
        batch_size=batch_size,
        callbacks=callbacks_list,
        verbose=1,
        shuffle=True,
    )
    if x_val is not None and y_val is not None:
        fit_kwargs['validation_data'] = (x_val, y_val)
    else:
        fit_kwargs['validation_split'] = 0.2

    start_time = time()
    history = model.fit(x_train, y_train, **fit_kwargs)
    training_time = time() - start_time

    print(f"\n✅ Training completed in {training_time:.2f} seconds ({training_time/60:.2f} minutes)")

    # The model may have been compiled without an accuracy metric; a missing
    # key must not raise here and throw away the finished training run.
    for label, key in (('training', 'accuracy'), ('validation', 'val_accuracy')):
        values = history.history.get(key)
        if values:
            print(f"📊 Final {label} accuracy: {values[-1]*100:.2f}%")

    return history, training_time

# Cell status message (mis-encoded check mark repaired).
print("✅ Training function defined")

## 5. Simple CNN-LSTM Baseline (for comparison)

In [None]:
def build_simple_cnn_lstm(input_shape):
    """
    Assemble the plain CNN-LSTM baseline used for comparison.

    Two Conv1D + max-pool stages feed two stacked LSTMs, followed by a small
    dense head with dropout and a single sigmoid output. The returned
    Sequential model is not compiled.
    """
    layer_stack = [
        Input(shape=input_shape),
        Conv1D(filters=32, kernel_size=7, activation='relu', padding='same'),
        MaxPooling1D(pool_size=2),
        Conv1D(filters=64, kernel_size=5, activation='relu', padding='same'),
        MaxPooling1D(pool_size=2),
        # First LSTM keeps the sequence; the second returns only its last state.
        LSTM(units=64, return_sequences=True),
        LSTM(units=32, return_sequences=False),
        Dense(units=32, activation='relu'),
        Dropout(rate=0.3),
        Dense(units=1, activation='sigmoid'),
    ]

    return keras.Sequential(layer_stack, name='Simple_CNN_LSTM')

# Cell status message (mis-encoded check mark repaired).
print("✅ Baseline model defined")

## 6. Example Usage

In [None]:
# Example: Create and compile model
# Uncomment to test

# input_shape = (1024, 1)  # 1024 time steps, 1 feature
# 
# # Build attention model
# model = build_attention_bilstm_cnn(input_shape, use_attention=True)
# model = compile_attention_model(model, learning_rate=0.001)
# 
# # Display model summary
# model.summary()
# 
# # Count parameters
# total_params = model.count_params()
# print(f"\n📊 Total parameters: {total_params:,}")

## 7. Model Visualization

In [None]:
def visualize_model_architecture(model, save_path='model_architecture.png'):
    """
    Render the model graph to an image file (best effort).

    Parameters:
    -----------
    model : keras.Model
        Model to draw
    save_path : str
        Destination image path passed to ``plot_model``

    Returns:
    --------
    None. Any failure (including a failed import of ``plot_model``) is
    reported on stdout instead of being raised, so a failed plot never
    interrupts the notebook.
    """
    try:
        # Imported lazily so the import error is reported like any other failure.
        from tensorflow.keras.utils import plot_model
        plot_model(model, to_file=save_path, show_shapes=True,
                   show_layer_names=True, rankdir='TB', dpi=96)
        print(f"✅ Model architecture saved to {save_path}")
    except Exception as e:
        # Deliberately broad: plotting is optional and must not raise.
        print(f"⚠️ Could not visualize model: {e}")
        print("Install graphviz and pydot if needed: pip install pydot graphviz")

# Closing banner listing the helpers defined in this notebook
# (mis-encoded check mark repaired).
print("\n" + "="*70)
print("✅ Attention-Enhanced BiLSTM-CNN Model loaded successfully!")
print("="*70)
print("\nAvailable functions:")
print("  - build_attention_bilstm_cnn(input_shape, use_attention)")
print("  - compile_attention_model(model, learning_rate)")
print("  - train_attention_model(model, x_train, y_train, ...)")
print("  - build_simple_cnn_lstm(input_shape) [baseline]")
print("  - visualize_model_architecture(model, save_path)")
print("\n" + "="*70)