# 🚀 Hybrid CNN-Transformer Sleep Disorder Classification

## 🎯 Architecture Overview:
This notebook implements a **hybrid CNN-Transformer model** that combines:
- **CNN layers** for local temporal feature extraction
- **Transformer encoders** for capturing global dependencies
- **Multi-head attention** for learning multiple representation subspaces

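## 📌 Key Advantages:
✅ **Global context** - Self-attention relates every time step to every other time step within a single layer  
✅ **Parallel processing** - All time steps are processed at once, unlike the step-by-step recurrence of an LSTM  
✅ **Local patterns** - Convolutions extract short-range features, and two pooling stages shorten the sequence from 1024 to 256 steps before attention  
✅ **Positional encoding** - Sinusoidal encodings restore the ordering information that attention alone ignores  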

---

## 1️⃣ Mount Google Drive (For Colab)

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

print("✅ Google Drive mounted successfully!")
print("\nYour files are now accessible at: /content/drive/MyDrive/")

## 2️⃣ Install & Import Required Libraries

In [None]:
# Install packages if needed
# !pip install -q tensorflow scikit-learn matplotlib seaborn pandas numpy

print("✅ All packages ready!")

In [None]:
# Standard libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, roc_curve, auc, confusion_matrix, classification_report
)
import warnings
warnings.filterwarnings('ignore')

# Deep learning
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import (
    Input, Dense, Conv1D, MaxPooling1D, Flatten, Dropout,
    BatchNormalization, Activation, GlobalAveragePooling1D, Layer
)
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from time import time

# Set seeds for reproducibility
np.random.seed(42)
tf.random.set_seed(42)

print("✅ Imports successful!")
print(f"TensorFlow version: {tf.__version__}")
print(f"GPU Available: {len(tf.config.list_physical_devices('GPU')) > 0}")

%matplotlib inline

## 3️⃣ Data Augmentation Functions

In [None]:
def jitter(x, sigma=0.03):
    """Add random Gaussian noise"""
    noise = np.random.normal(loc=0., scale=sigma * np.std(x), size=x.shape)
    return x + noise

def scaling(x, sigma=0.1):
    """Randomly scale the signal amplitude"""
    factor = np.random.normal(loc=1., scale=sigma)
    return x * factor

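def time_warp(x, sigma=0.2, knot=4):
    """Apply smooth random time warping to the signal"""
    orig_steps = np.arange(x.shape[0])
    # Random local speed factors at knot+2 evenly spaced anchor points
    random_warps = np.random.normal(loc=1.0, scale=sigma, size=(knot+2,))
    warp_steps = np.linspace(0, x.shape[0]-1, num=knot+2)
    # Interpolate the speed factors to every sample, then normalize so the
    # cumulative warped time axis spans the original signal length
    ret = np.interp(orig_steps, warp_steps, random_warps)
    ret = ret / ret.sum() * x.shape[0]
    ret = np.cumsum(ret)
    # Resample the signal along the warped time axis
    if len(x.shape) == 1:
        return np.interp(orig_steps, ret, x)
    else:
        return np.array([np.interp(orig_steps, ret, x[:, i]) for i in range(x.shape[1])]).T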

def augment_signal(x, augmentation_list=('jitter', 'scaling', 'time_warp'), n_augmentations=2):
    """Apply random augmentations to a signal"""
    augmented = x.copy()
    selected = np.random.choice(augmentation_list, size=min(n_augmentations, len(augmentation_list)), replace=False)
    for aug in selected:
        if aug == 'jitter':
            augmented = jitter(augmented)
        elif aug == 'scaling':
            augmented = scaling(augmented)
        elif aug == 'time_warp':
            augmented = time_warp(augmented)
    return augmented

def augment_dataset(X, y, augmentation_factor=1):
    """Augment entire dataset"""
    X_aug_list = [X]
    y_aug_list = [y]
    for i in range(augmentation_factor):
        X_new = np.array([augment_signal(x) for x in X])
        X_aug_list.append(X_new)
        y_aug_list.append(y)
    X_aug = np.concatenate(X_aug_list, axis=0)
    y_aug = np.concatenate(y_aug_list, axis=0)
    indices = np.random.permutation(len(X_aug))
    X_aug = X_aug[indices]
    y_aug = y_aug[indices]
    print(f"Original dataset size: {len(X)}")
    print(f"Augmented dataset size: {len(X_aug)}")
    return X_aug, y_aug

print("✅ Data augmentation functions loaded!")
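
To get a feel for what these transforms do to a single window, the sketch below applies jitter-style and scaling-style perturbations to a synthetic sine wave. The sine wave and the local `rng` generator are illustrative assumptions, and the two functions are re-declared in simplified form so the block runs on its own.

```python
import numpy as np

rng = np.random.default_rng(42)  # local generator; the notebook itself uses np.random.seed

def jitter(x, sigma=0.03):
    """Add Gaussian noise scaled by the signal's standard deviation."""
    return x + rng.normal(0.0, sigma * np.std(x), size=x.shape)

def scaling(x, sigma=0.1):
    """Multiply the whole signal by one random factor drawn around 1."""
    return x * rng.normal(1.0, sigma)

# Synthetic stand-in for one 1024-sample recording (assumption, not real data)
t = np.linspace(0, 1, 1024)
signal = np.sin(2 * np.pi * 5 * t)

noisy = jitter(signal)
scaled = scaling(signal)

# Jitter perturbs each sample slightly; scaling changes amplitude but not shape
noise_rms = np.sqrt(np.mean((noisy - signal) ** 2))
corr = np.corrcoef(signal, scaled)[0, 1]
print(f"jitter RMS noise: {noise_rms:.4f}")
print(f"correlation after scaling: {corr:.4f}")
```

With `augmentation_factor=1`, `augment_dataset` returns the originals plus one augmented copy of each, so the training set doubles in size.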

## 4️⃣ Hybrid CNN-Transformer Architecture 🚀

In [None]:
# Positional Encoding Layer
class PositionalEncoding(Layer):
    def __init__(self, **kwargs):
        super(PositionalEncoding, self).__init__(**kwargs)
    
    def build(self, input_shape):
        seq_len = input_shape[1]
        d_model = input_shape[2]
        
        # Create positional encoding matrix
        position = np.arange(seq_len)[:, np.newaxis]
        div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
        
        pe = np.zeros((seq_len, d_model))
        pe[:, 0::2] = np.sin(position * div_term)
        if d_model % 2 == 0:
            pe[:, 1::2] = np.cos(position * div_term)
        else:
            pe[:, 1::2] = np.cos(position * div_term[:-1])
        
        self.pe = tf.constant(pe, dtype=tf.float32)
        super(PositionalEncoding, self).build(input_shape)
    
    def call(self, x):
        return x + self.pe
    
    def get_config(self):
        return super(PositionalEncoding, self).get_config()


# Transformer Encoder Block
def transformer_encoder_block(x, num_heads=4, ff_dim=128, dropout_rate=0.1, name_prefix='transformer'):
    """
    Single Transformer Encoder Block with Multi-Head Attention
    """
    # Multi-Head Attention
    attn_output = layers.MultiHeadAttention(
        num_heads=num_heads,
        key_dim=x.shape[-1] // num_heads,
        dropout=dropout_rate,
        name=f'{name_prefix}_mha'
    )(x, x)
    attn_output = Dropout(dropout_rate, name=f'{name_prefix}_dropout1')(attn_output)
    out1 = layers.LayerNormalization(epsilon=1e-6, name=f'{name_prefix}_ln1')(x + attn_output)
    
    # Feed-Forward Network
    ffn_output = Dense(ff_dim, activation='relu', name=f'{name_prefix}_ffn1')(out1)
    ffn_output = Dropout(dropout_rate, name=f'{name_prefix}_dropout2')(ffn_output)
    ffn_output = Dense(x.shape[-1], name=f'{name_prefix}_ffn2')(ffn_output)
    ffn_output = Dropout(dropout_rate, name=f'{name_prefix}_dropout3')(ffn_output)
    out2 = layers.LayerNormalization(epsilon=1e-6, name=f'{name_prefix}_ln2')(out1 + ffn_output)
    
    return out2


# Build Hybrid CNN-Transformer Model
def build_cnn_transformer(input_shape, num_transformer_blocks=2, num_heads=4, ff_dim=128):
    """
    Hybrid CNN-Transformer Architecture:
    - CNN layers extract local temporal features
    - Transformer blocks capture global dependencies
    - Best of both worlds!
    """
    input_signal = Input(shape=input_shape, name='input')
    
    # ============ CNN Feature Extraction ============
    # Block 1: Initial convolutions
    x = Conv1D(filters=64, kernel_size=7, strides=1, padding='same', name='cnn_conv1')(input_signal)
    x = BatchNormalization(name='cnn_bn1')(x)
    x = Activation('relu', name='cnn_relu1')(x)
    x = MaxPooling1D(pool_size=2, padding='same', name='cnn_pool1')(x)
    x = Dropout(0.2, name='cnn_dropout1')(x)
    
    # Block 2: Deeper features
    x = Conv1D(filters=128, kernel_size=5, strides=1, padding='same', name='cnn_conv2')(x)
    x = BatchNormalization(name='cnn_bn2')(x)
    x = Activation('relu', name='cnn_relu2')(x)
    x = MaxPooling1D(pool_size=2, padding='same', name='cnn_pool2')(x)
    x = Dropout(0.2, name='cnn_dropout2')(x)
    
    # Block 3: Feature refinement
    x = Conv1D(filters=128, kernel_size=3, strides=1, padding='same', name='cnn_conv3')(x)
    x = BatchNormalization(name='cnn_bn3')(x)
    x = Activation('relu', name='cnn_relu3')(x)
    
    # ============ Positional Encoding ============
    x = PositionalEncoding(name='pos_encoding')(x)
    
    # ============ Transformer Encoder Blocks ============
    for i in range(num_transformer_blocks):
        x = transformer_encoder_block(
            x,
            num_heads=num_heads,
            ff_dim=ff_dim,
            dropout_rate=0.1,
            name_prefix=f'transformer_block_{i+1}'
        )
    
    # ============ Global Pooling ============
    x = GlobalAveragePooling1D(name='global_avg_pool')(x)
    
    # ============ Classification Head ============
    x = Dense(64, activation='relu', kernel_regularizer=tf.keras.regularizers.L2(0.01), name='dense1')(x)
    x = Dropout(0.4, name='dropout_final1')(x)
    x = Dense(32, activation='relu', kernel_regularizer=tf.keras.regularizers.L2(0.01), name='dense2')(x)
    x = Dropout(0.3, name='dropout_final2')(x)
    output = Dense(1, activation='sigmoid', name='output')(x)
    
    model = keras.Model(inputs=input_signal, outputs=output, name='CNN_Transformer_Hybrid')
    return model

print("✅ CNN-Transformer hybrid architecture defined!")
print("   🔹 CNN layers: Extract local temporal patterns")
print("   🔹 Positional encoding: Preserve sequence order")
print("   🔹 Transformer blocks: Capture global dependencies")
print("   🔹 Multi-head attention: Learn multiple representation subspaces")

## 5️⃣ Visualization Functions

In [None]:
def plot_training_history(history, title='Training History'):
    """Plot training and validation metrics"""
    if hasattr(history, 'history'):
        history = history.history
    
    fig, axes = plt.subplots(1, 2, figsize=(16, 5))
    
    # Accuracy
    axes[0].plot(history['accuracy'], 'b-', linewidth=2, label='Training')
    axes[0].plot(history['val_accuracy'], 'r-', linewidth=2, label='Validation')
    axes[0].set_xlabel('Epoch', fontsize=12, fontweight='bold')
    axes[0].set_ylabel('Accuracy', fontsize=12, fontweight='bold')
    axes[0].set_title('Model Accuracy', fontsize=14, fontweight='bold')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # Loss
    axes[1].plot(history['loss'], 'b-', linewidth=2, label='Training')
    axes[1].plot(history['val_loss'], 'r-', linewidth=2, label='Validation')
    axes[1].set_xlabel('Epoch', fontsize=12, fontweight='bold')
    axes[1].set_ylabel('Loss', fontsize=12, fontweight='bold')
    axes[1].set_title('Model Loss', fontsize=14, fontweight='bold')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
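    # Show the caller-supplied title, which was previously accepted but never used
    fig.suptitle(title, fontsize=15, fontweight='bold')
    plt.tight_layout()
    plt.show()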

def plot_confusion_matrix(y_true, y_pred, class_names=['Healthy', 'Unhealthy']):
    """Plot confusion matrix"""
    cm = confusion_matrix(y_true, y_pred)
    cm_norm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm_norm, annot=True, fmt='.2%', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names,
                linewidths=2, linecolor='white', ax=ax)
    
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j + 0.5, i + 0.7, f'({cm[i, j]})',
                   ha='center', va='center', fontsize=10, color='gray')
    
    ax.set_ylabel('True Label', fontsize=13, fontweight='bold')
    ax.set_xlabel('Predicted Label', fontsize=13, fontweight='bold')
    ax.set_title('Confusion Matrix', fontsize=15, fontweight='bold')
    plt.tight_layout()
    plt.show()

def plot_roc_curve(y_true, y_pred_proba):
    """Plot ROC curve"""
    fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
    roc_auc = auc(fpr, tpr)
    
    plt.figure(figsize=(10, 8))
    plt.plot(fpr, tpr, color='darkorange', lw=3, label=f'ROC curve (AUC = {roc_auc:.3f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random')
    plt.fill_between(fpr, tpr, 0, alpha=0.2, color='orange')
    plt.xlabel('False Positive Rate', fontsize=12, fontweight='bold')
    plt.ylabel('True Positive Rate', fontsize=12, fontweight='bold')
    plt.title('ROC Curve', fontsize=13, fontweight='bold')
    plt.legend(loc='lower right')
    plt.grid(True, alpha=0.3)
    plt.show()
    return roc_auc

def generate_metrics_report(y_true, y_pred, y_pred_proba):
    """Generate comprehensive metrics"""
    cm = confusion_matrix(y_true, y_pred)
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    roc_auc = roc_auc_score(y_true, y_pred_proba)
    
    if cm.shape == (2, 2):
        tn, fp, fn, tp = cm.ravel()
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    else:
        specificity = 0
    
    print("\n" + "="*70)
    print("üìä PERFORMANCE METRICS")
    print("="*70)
    print(f"Accuracy:     {accuracy:.4f} ({accuracy*100:.2f}%)")
    print(f"Precision:    {precision:.4f} ({precision*100:.2f}%)")
    print(f"Recall:       {recall:.4f} ({recall*100:.2f}%)")
    print(f"Specificity:  {specificity:.4f} ({specificity*100:.2f}%)")
    print(f"F1-Score:     {f1:.4f} ({f1*100:.2f}%)")
    print(f"ROC-AUC:      {roc_auc:.4f} ({roc_auc*100:.2f}%)")
    print("="*70 + "\n")
    
    print("CLASSIFICATION REPORT:")
    print(classification_report(y_true, y_pred, target_names=['Healthy', 'Unhealthy'], digits=4))
    
    return {
        'Accuracy': accuracy,
        'Precision': precision,
        'Recall': recall,
        'Specificity': specificity,
        'F1-Score': f1,
        'ROC-AUC': roc_auc
    }

print("✅ Visualization functions loaded!")
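
`generate_metrics_report` derives specificity from the four confusion-matrix counts, which scikit-learn returns in the order `tn, fp, fn, tp` for a binary problem. As a worked example of how each reported metric follows from those counts, here is a dependency-free sketch. The function name `binary_metrics` and the counts are hypothetical, not results from this notebook.

```python
def binary_metrics(tn, fp, fn, tp):
    """Compute common binary metrics from confusion-matrix counts."""
    total = tn + fp + fn + tp
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0        # also called sensitivity
    specificity = tn / (tn + fp) if (tn + fp) else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if (precision + recall) else 0.0)
    return {
        "accuracy": (tp + tn) / total if total else 0.0,
        "precision": precision,
        "recall": recall,
        "specificity": specificity,
        "f1": f1,
    }

# Hypothetical counts for illustration only
m = binary_metrics(tn=40, fp=10, fn=5, tp=45)
for name, value in m.items():
    print(f"{name:12s} {value:.4f}")
```

Recall and specificity are computed on disjoint rows of the confusion matrix, so reporting both shows how the classifier treats each class separately.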

## 6️⃣ Load Dataset

**⚠️ IMPORTANT: Update DATA_PATH below**

In [None]:
# TODO: Update this path to your dataset location
DATA_PATH = '/content/drive/MyDrive/your_folder/healthy_unhealthy1.csv'

# Load data
data = np.loadtxt(DATA_PATH, delimiter=',')
print("✅ Data loaded successfully!")
print(f"   Shape: {data.shape}")

# Split features and labels
X = data[:, 0:1024]  # First 1024 columns
y = data[:, -1]      # Last column (label)

print("\n📊 Dataset Statistics:")
print(f"   Total samples: {len(X)}")
print(f"   Feature dimensions: {X.shape[1]}")
print(f"   Healthy samples: {np.sum(y == 0)} ({np.sum(y == 0)/len(y)*100:.1f}%)")
print(f"   Unhealthy samples: {np.sum(y == 1)} ({np.sum(y == 1)/len(y)*100:.1f}%)")
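
The slicing above takes the first 1024 columns as the signal and the last column as the label, but it does not verify the file layout. The sketch below shows one way to check it first, assuming the CSV holds exactly 1024 signal columns followed by a single 0/1 label column. That layout, the helper name `split_features_labels`, and the synthetic array are all assumptions for illustration.

```python
import numpy as np

N_FEATURES = 1024  # signal length used throughout the notebook

def split_features_labels(data, n_features=N_FEATURES):
    """Split a 2-D array into (X, y) after checking the assumed column layout."""
    if data.ndim != 2 or data.shape[1] != n_features + 1:
        raise ValueError(
            f"expected {n_features + 1} columns, got shape {data.shape}"
        )
    labels = data[:, -1]
    if not np.isin(labels, (0, 1)).all():
        raise ValueError("labels must be 0 or 1")
    return data[:, :n_features], labels.astype(int)

# Synthetic stand-in for the loaded CSV (not real data)
rng = np.random.default_rng(0)
signals = rng.normal(size=(6, N_FEATURES))
labels = np.array([0, 1, 0, 1, 1, 0], dtype=float).reshape(-1, 1)
fake_data = np.hstack([signals, labels])

X_demo, y_demo = split_features_labels(fake_data)
print(X_demo.shape, y_demo)
```

A check like this fails early if extra columns are present, instead of silently dropping them.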

## 7️⃣ Data Preparation & Augmentation

In [None]:
# Split data
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, shuffle=True, stratify=y, random_state=42
)

print("📊 Data Split:")
print(f"   Training: {len(X_train)} samples")
print(f"   Test: {len(X_test)} samples")

In [None]:
# Apply augmentation
APPLY_AUGMENTATION = True
AUGMENTATION_FACTOR = 1  # Increase for more augmented data

if APPLY_AUGMENTATION and AUGMENTATION_FACTOR > 0:
    print("üîÑ Applying data augmentation...")
    X_train_aug, y_train_aug = augment_dataset(X_train, y_train, AUGMENTATION_FACTOR)
else:
    X_train_aug, y_train_aug = X_train, y_train

# Reshape for CNN-Transformer input
X_train_aug = X_train_aug.reshape(-1, 1024, 1)
X_test_reshaped = X_test.reshape(-1, 1024, 1)

print("\n✅ Final shapes:")
print(f"   Training: {X_train_aug.shape}")
print(f"   Test: {X_test_reshaped.shape}")

## 8️⃣ Build & Compile Model

In [None]:
# Build CNN-Transformer model
input_shape = (1024, 1)

print("🏗️ Building CNN-Transformer model...\n")
model = build_cnn_transformer(
    input_shape,
    num_transformer_blocks=2,  # Number of transformer encoder blocks
    num_heads=4,                # Number of attention heads
    ff_dim=128                  # Feed-forward dimension
)

# Compile
optimizer = keras.optimizers.Adam(learning_rate=0.001)
model.compile(
    optimizer=optimizer,
    loss='binary_crossentropy',
    metrics=['accuracy']
)

model.summary()
print("\n✅ Model built successfully!")
print(f"   Total parameters: {model.count_params():,}")
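
To sanity-check `model.summary()`, the tensor shape entering the Transformer blocks can be worked out by hand. The sketch below traces the sequence length and channel count through the CNN front end using only the layer settings from `build_cnn_transformer`. The helper `same_pool_length` is an illustrative name, not part of Keras.

```python
import math

def same_pool_length(length, pool_size=2):
    """Output length of MaxPooling1D with padding='same' and stride equal to pool_size."""
    return math.ceil(length / pool_size)

seq_len, channels = 1024, 1          # input_shape = (1024, 1)

# Conv1D with padding='same' and strides=1 keeps the length and sets the channel count
channels = 64                        # cnn_conv1
seq_len = same_pool_length(seq_len)  # cnn_pool1
channels = 128                       # cnn_conv2
seq_len = same_pool_length(seq_len)  # cnn_pool2
# cnn_conv3 leaves the shape at (seq_len, 128)

num_heads = 4
key_dim = channels // num_heads      # per-head dimension passed to MultiHeadAttention

print(seq_len, channels, key_dim)    # 256 128 32
```

Attention therefore runs over 256 tokens of width 128, so each head computes a 256 x 256 score matrix instead of the 1024 x 1024 matrix the raw signal would require.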

## 9️⃣ Train Model

In [None]:
# Training configuration
EPOCHS = 150
BATCH_SIZE = 64
MODEL_SAVE_PATH = '/content/drive/MyDrive/cnn_transformer_model.h5'

# Callbacks
callbacks = [
    EarlyStopping(
        monitor='val_loss',
        patience=30,
        verbose=1,
        restore_best_weights=True
    ),
    ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=10,
        min_lr=1e-6,
        verbose=1
    ),
    ModelCheckpoint(
        MODEL_SAVE_PATH,
        monitor='val_accuracy',
        verbose=1,
        save_best_only=True,
        mode='max'
    )
]

print("üöÄ Starting training...\n")
start_time = time()

history = model.fit(
    X_train_aug, y_train_aug,
    validation_split=0.2,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    callbacks=callbacks,
    verbose=1,
    shuffle=True
)

training_time = time() - start_time

print("\n✅ Training completed!")
print(f"   Total time: {training_time:.2f}s ({training_time/60:.2f} min)")
print(f"   Model saved to: {MODEL_SAVE_PATH}")
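
One caveat about `validation_split=0.2`: Keras takes the validation samples from the end of the arrays it is given, before shuffling. Because `augment_dataset` mixes originals and augmented copies, the validation set can contain augmented variants of windows the model trains on, which may make validation metrics look optimistic. A possible alternative, sketched below under stated assumptions, is to hold out validation rows first and augment only the remaining training rows. The helper `split_then_augment`, the stand-in `toy_augment`, and the synthetic data are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(42)

def split_then_augment(X, y, val_fraction=0.2, augment_fn=None):
    """Hold out validation rows first, then augment only the training rows."""
    idx = rng.permutation(len(X))
    n_val = int(len(X) * val_fraction)
    val_idx, train_idx = idx[:n_val], idx[n_val:]
    X_tr, y_tr = X[train_idx], y[train_idx]
    if augment_fn is not None:
        # One augmented copy per training row; validation rows stay untouched
        X_new = np.array([augment_fn(x) for x in X_tr])
        X_tr = np.concatenate([X_tr, X_new], axis=0)
        y_tr = np.concatenate([y_tr, y_tr], axis=0)
    return (X_tr, y_tr), (X[val_idx], y[val_idx])

def toy_augment(x):
    """Stand-in for augment_signal: add a little Gaussian noise."""
    return x + rng.normal(0.0, 0.01, size=x.shape)

# Synthetic data for illustration only
X_demo = rng.normal(size=(50, 1024))
y_demo = rng.integers(0, 2, size=50)

(X_tr, y_tr), (X_val, y_val) = split_then_augment(X_demo, y_demo, augment_fn=toy_augment)
print(X_tr.shape, X_val.shape)  # (80, 1024) (10, 1024)
```

The held-out arrays would then be reshaped like the training data and passed to `model.fit` through `validation_data=(X_val, y_val)` in place of `validation_split`. This sketch does not stratify the split, so class proportions in the validation set may differ slightly from the training set.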

## 🔟 Evaluate Model

In [None]:
# Generate predictions
print("📊 Generating predictions...\n")
y_pred_proba = model.predict(X_test_reshaped, verbose=0).flatten()
y_pred = (y_pred_proba > 0.5).astype(int)

# Plot training history
plot_training_history(history, 'CNN-Transformer Training')

# Plot confusion matrix
plot_confusion_matrix(y_test, y_pred)

# Plot ROC curve
plot_roc_curve(y_test, y_pred_proba)

# Generate metrics report
metrics = generate_metrics_report(y_test, y_pred, y_pred_proba)

## 1️⃣1️⃣ Final Summary

In [None]:
print("\n" + "="*80)
print("🎊 CNN-TRANSFORMER MODEL SUMMARY")
print("="*80)

print("\n📊 Dataset:")
print(f"   Total samples: {len(X)}")
print(f"   Training samples: {len(X_train_aug)}")
print(f"   Test samples: {len(X_test)}")

print("\n🏆 Model Performance:")
print(f"   Accuracy: {metrics['Accuracy']*100:.2f}%")
print(f"   F1-Score: {metrics['F1-Score']*100:.2f}%")
print(f"   ROC-AUC: {metrics['ROC-AUC']*100:.2f}%")
print(f"   Precision: {metrics['Precision']*100:.2f}%")
print(f"   Recall: {metrics['Recall']*100:.2f}%")

print("\n💾 Saved Artifacts:")
print(f"   Model: {MODEL_SAVE_PATH}")

print("\n🎯 Architecture Features:")
print("   ✅ CNN layers for local feature extraction")
print("   ✅ Positional encoding for temporal information")
print("   ✅ Multi-head attention for global dependencies")
print("   ✅ Transformer encoders for sequence modeling")
print("   ✅ Layer normalization for training stability")

print("\n" + "="*80)
print("✅ CNN-TRANSFORMER CLASSIFICATION COMPLETED!")
print("="*80 + "\n")

## 📝 Next Steps

### Model Improvements:
- Adjust `num_transformer_blocks` (2-4 blocks)
- Experiment with `num_heads` (4, 8, or 16)
- Tune `ff_dim` (feed-forward dimension)
- Try different learning rates
- Increase augmentation factor

### Analysis:
- Compare with BiLSTM-CNN-Attention model
- Analyze attention weights
- Perform cross-validation
- Test on external datasets
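
For the cross-validation item, one common option is scikit-learn's `StratifiedKFold`, which keeps the class ratio roughly constant across folds. The sketch below only demonstrates the splitting loop on synthetic data. The array contents, fold count, and variable names are assumptions, and in practice a fresh model would be built and trained inside the loop.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Synthetic stand-in for the signals and labels (not real data)
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(40, 1024))
y_demo = np.array([0, 1] * 20)   # balanced binary labels

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
fold_sizes = []
for fold, (train_idx, val_idx) in enumerate(skf.split(X_demo, y_demo), start=1):
    # A new model would be built, trained on train_idx, and scored on val_idx here
    pos_rate = y_demo[val_idx].mean()
    fold_sizes.append(len(val_idx))
    print(f"fold {fold}: train={len(train_idx)}, val={len(val_idx)}, "
          f"positive rate={pos_rate:.2f}")
```

Averaging the per-fold metrics gives a less split-dependent estimate than the single 80/20 split used above, at the cost of training the model once per fold.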

---

**🚀 The hybrid CNN-Transformer pairs convolutional local feature extraction with attention-based global context for time-series classification!**