## 1️⃣ Setup Environment

In [None]:
# Check GPU availability
import tensorflow as tf

print(f"TensorFlow version: {tf.__version__}")

# Query the physical GPU list once and reuse it for the report and the
# memory-growth configuration below.
gpus = tf.config.list_physical_devices('GPU')
print(f"GPU Available: {gpus}")

# Enable memory growth so TensorFlow allocates GPU memory on demand instead of
# reserving it all up front (helps avoid OOM errors).
if gpus:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    print(f"✓ GPU configured: {gpus[0].name}")

In [None]:
# Install required packages
!pip install -q mne scipy scikit-learn matplotlib seaborn

In [None]:
# Make the project files available to this runtime -- pick ONE option.
#
# Option A: clone from GitHub (replace the URL with your repository):
# !git clone https://github.com/YOUR_USERNAME/eeg-seizure-prediction.git
# %cd eeg-seizure-prediction
#
# Option B: upload the files manually with the file browser on the left.
#
# Option C (active by default): mount Google Drive, if the files live there.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

# If the project lives in Google Drive, uncomment and adjust this path:
# %cd /content/drive/MyDrive/YourProjectFolder

## 2️⃣ Project Files Setup

Run the cells in this section to create the project directories and write all necessary project files directly in Colab:

In [None]:
import os

# Create the project directory layout. The %%writefile cells below write into
# `models/` and `data/`, so those directories must exist before they run.
PROJECT_DIRS = (
    'data/raw',
    'data/processed',
    'saved_models',
    'results',
    'models',
    'training',
    'utils',
)
for project_dir in PROJECT_DIRS:
    os.makedirs(project_dir, exist_ok=True)

print("✓ Directory structure created")

In [None]:
%%writefile config.py
"""
Configuration Settings for EEG Seizure Prediction
"""
import os
from dataclasses import dataclass
from typing import Tuple

# Project root: the directory holding this file when it is imported as a
# module; fall back to the current working directory when no __file__ exists.
try:
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
except NameError:
    BASE_DIR = os.getcwd()

# Standard sub-directories, all resolved relative to BASE_DIR.
DATA_DIR = os.path.join(BASE_DIR, "data", "raw")
PROCESSED_DIR = os.path.join(BASE_DIR, "data", "processed")
MODEL_DIR = os.path.join(BASE_DIR, "saved_models")
RESULTS_DIR = os.path.join(BASE_DIR, "results")

# Importing this module guarantees the directories exist.
for dir_path in (DATA_DIR, PROCESSED_DIR, MODEL_DIR, RESULTS_DIR):
    os.makedirs(dir_path, exist_ok=True)

@dataclass
class EEGConfig:
    """Signal-processing and windowing parameters for the EEG recordings.

    ``window_samples`` is ``window_duration * target_sampling_rate``, so if the
    sampling rates are in Hz the durations are in seconds (assumed).
    """

    # Sampling
    original_sampling_rate: int = 256
    target_sampling_rate: int = 256
    n_channels: int = 22
    # Filtering (frequencies presumably in Hz)
    bandpass_low: float = 0.5
    bandpass_high: float = 40.0
    notch_freq: float = 60.0
    notch_width: float = 2.0
    # Windowing
    window_duration: float = 10.0
    window_overlap: float = 0.5  # overlap between consecutive windows as a fraction (assumed)
    # Labelling horizons (presumably seconds)
    preictal_duration: int = 300
    seizure_prediction_horizon: int = 30

    @property
    def window_samples(self) -> int:
        """Samples per window at the target rate, truncated toward zero."""
        samples = self.window_duration * self.target_sampling_rate
        return int(samples)

@dataclass
class CNNConfig:
    """Hyper-parameters for the convolutional front end.

    The three tuples hold one entry per convolutional block and mirror the
    ``filters`` / ``kernel_sizes`` / ``pool_sizes`` arguments of CNNEncoder.
    Nothing in this notebook passes this config to the model automatically.
    """

    conv_filters: Tuple[int, ...] = (64, 128, 256)   # output channels per block
    conv_kernel_sizes: Tuple[int, ...] = (7, 5, 3)   # kernel width per block
    pool_sizes: Tuple[int, ...] = (2, 2, 2)          # max-pool size per block
    activation: str = "relu"
    dropout_rate: float = 0.3
    use_batch_norm: bool = True

@dataclass
class TransformerConfig:
    """Hyper-parameters for the Transformer encoder.

    Note: SeizurePredictorCNNTransformer derives its transformer width from
    ``conv_filters[-1]`` rather than reading ``d_model`` here, so keep the
    two consistent by hand.
    """

    d_model: int = 256
    n_heads: int = 8
    n_layers: int = 4
    d_ff: int = 512                 # feed-forward hidden width
    dropout_rate: float = 0.1
    attention_dropout: float = 0.1
    max_seq_length: int = 500       # longest sequence the positional encoding covers (assumed)

@dataclass
class ClassifierConfig:
    """Hyper-parameters for the dense classification head."""

    hidden_units: Tuple[int, ...] = (128, 64)  # width of each hidden Dense layer
    dropout_rate: float = 0.5
    n_classes: int = 1                         # one sigmoid output (binary task, assumed)

@dataclass
class TrainingConfig:
    """Data-split ratios and optimisation settings for training."""

    # Split fractions (they sum to 1.0 by default)
    train_ratio: float = 0.70
    val_ratio: float = 0.15
    test_ratio: float = 0.15
    # Optimisation
    batch_size: int = 32
    epochs: int = 100
    learning_rate: float = 1e-4
    weight_decay: float = 1e-5
    optimizer: str = "adam"
    early_stopping_patience: int = 15
    use_class_weights: bool = True
    random_seed: int = 42

# Default configuration instances; import them directly,
# e.g. `from config import eeg_config, training_config`.
eeg_config, cnn_config, transformer_config = EEGConfig(), CNNConfig(), TransformerConfig()
classifier_config, training_config = ClassifierConfig(), TrainingConfig()

In [None]:
%%writefile models/cnn_encoder.py
"""CNN Encoder for EEG feature extraction"""
import tensorflow as tf
from tensorflow.keras import layers

class CNNEncoder(layers.Layer):
    """Stack of Conv1D -> BatchNorm -> MaxPool blocks over the time axis.

    ``call`` receives channel-first input ``(batch, channels, time)`` and
    transposes it to ``(batch, time, channels)`` before the convolutions.
    Convolutions use 'same' padding. A single dropout is applied to the
    output of the last block.

    Args:
        filters: Conv1D filter count per block.
        kernel_sizes: Conv1D kernel size per block.
        pool_sizes: MaxPooling1D pool size per block.
        dropout_rate: Rate of the final dropout.

    The three sequences are zipped, so the block count equals the length of
    the shortest one.
    """

    def __init__(self, filters=(64, 128, 256), kernel_sizes=(7, 5, 3),
                 pool_sizes=(2, 2, 2), dropout_rate=0.3, **kwargs):
        super().__init__(**kwargs)
        # Each entry is a [conv, batch-norm, pool] triple, created in the order
        # in which it is applied.
        self.conv_blocks = []
        for n_filters, kernel, pool in zip(filters, kernel_sizes, pool_sizes):
            self.conv_blocks.append([
                layers.Conv1D(n_filters, kernel, padding='same', activation='relu'),
                layers.BatchNormalization(),
                layers.MaxPooling1D(pool, padding='same'),
            ])
        self.dropout = layers.Dropout(dropout_rate)

    def call(self, inputs, training=None):
        # (batch, channels, time) -> (batch, time, channels)
        features = tf.transpose(inputs, perm=[0, 2, 1])
        for block in self.conv_blocks:
            conv_layer, norm_layer, pool_layer = block
            features = pool_layer(norm_layer(conv_layer(features), training=training))
        return self.dropout(features, training=training)

In [None]:
%%writefile models/transformer_encoder.py
"""Transformer Encoder for temporal modeling"""
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np

class PositionalEncoding(layers.Layer):
    """Add fixed sinusoidal positional encodings to a sequence of features.

    Even feature columns receive ``sin`` terms and odd columns ``cos`` terms,
    with wavelengths following the 10000^(2i/d_model) schedule.

    Args:
        max_len: Number of positions in the precomputed table. Inputs whose
            second axis is longer than this will fail with a shape mismatch.
        d_model: Feature size; must match the last axis of the inputs. Odd
            values are supported.
    """

    def __init__(self, max_len=500, d_model=256, **kwargs):
        super().__init__(**kwargs)
        pe = np.zeros((max_len, d_model))
        position = np.arange(0, max_len)[:, np.newaxis]
        div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
        pe[:, 0::2] = np.sin(position * div_term)
        # There are d_model // 2 odd columns: one fewer than the even columns
        # when d_model is odd. Trim div_term so the assignment broadcasts.
        pe[:, 1::2] = np.cos(position * div_term[: d_model // 2])
        # Leading batch axis of 1 so the table broadcasts over the batch.
        self.pe = tf.constant(pe[np.newaxis, :, :], dtype=tf.float32)

    def call(self, x):
        # Add the encodings for the first seq_len positions.
        return x + self.pe[:, :tf.shape(x)[1], :]

class TransformerEncoderLayer(layers.Layer):
    """One post-norm Transformer encoder block.

    Self-attention and a two-layer ReLU feed-forward network. Each is
    followed by dropout, a residual connection and LayerNormalization.

    Args:
        d_model: Feature size of inputs and outputs.
        n_heads: Number of attention heads; each uses ``d_model // n_heads``
            as its key dimension.
        d_ff: Hidden width of the feed-forward network.
        dropout: Rate for attention dropout and both residual dropouts.
    """

    def __init__(self, d_model=256, n_heads=8, d_ff=512, dropout=0.1, **kwargs):
        super().__init__(**kwargs)
        head_dim = d_model // n_heads
        self.mha = layers.MultiHeadAttention(num_heads=n_heads, key_dim=head_dim, dropout=dropout)
        self.ffn = tf.keras.Sequential([
            layers.Dense(d_ff, activation='relu'),
            layers.Dense(d_model),
        ])
        self.ln1 = layers.LayerNormalization()
        self.ln2 = layers.LayerNormalization()
        self.dropout1 = layers.Dropout(dropout)
        self.dropout2 = layers.Dropout(dropout)

    def call(self, x, training=None):
        # Sub-layer 1: self-attention + residual + norm.
        attended = self.dropout1(self.mha(x, x, training=training), training=training)
        x = self.ln1(x + attended)
        # Sub-layer 2: position-wise feed-forward + residual + norm.
        transformed = self.dropout2(self.ffn(x), training=training)
        return self.ln2(x + transformed)

class TransformerEncoder(layers.Layer):
    """Sinusoidal positional encoding followed by a stack of encoder layers.

    Args:
        n_layers: Number of TransformerEncoderLayer blocks.
        d_model: Feature size of the inputs (last axis).
        n_heads: Attention heads per layer.
        d_ff: Hidden width of each feed-forward sub-layer.
        dropout: Dropout rate used inside each layer.
        max_len: Longest sequence (second input axis) the positional encoding
            supports. Defaults to 500, matching PositionalEncoding's own
            default; raise it for longer inputs.
    """

    def __init__(self, n_layers=4, d_model=256, n_heads=8, d_ff=512, dropout=0.1,
                 max_len=500, **kwargs):
        super().__init__(**kwargs)
        self.pos_encoding = PositionalEncoding(max_len=max_len, d_model=d_model)
        self.enc_layers = [TransformerEncoderLayer(d_model, n_heads, d_ff, dropout)
                           for _ in range(n_layers)]

    def call(self, x, training=None):
        x = self.pos_encoding(x)
        for layer in self.enc_layers:
            x = layer(x, training=training)
        return x

In [None]:
%%writefile models/seizure_predictor.py
"""Main CNN + Transformer Seizure Prediction Model"""
import tensorflow as tf
from tensorflow.keras import layers, Model
from models.cnn_encoder import CNNEncoder
from models.transformer_encoder import TransformerEncoder

class ClassificationHead(layers.Layer):
    """Global-average-pool a sequence, then Dense+Dropout layers and a sigmoid unit.

    Args:
        hidden_units: Width of each hidden ReLU Dense layer.
        dropout_rate: Dropout applied after every hidden layer.

    Returns one probability per sample, shape ``(batch, 1)``.
    """

    def __init__(self, hidden_units=(128, 64), dropout_rate=0.5, **kwargs):
        super().__init__(**kwargs)
        self.global_pool = layers.GlobalAveragePooling1D()
        self.dense_layers = [layers.Dense(u, activation='relu') for u in hidden_units]
        self.dropout_layers = [layers.Dropout(dropout_rate) for _ in hidden_units]
        self.output_layer = layers.Dense(1, activation='sigmoid')

    def call(self, x, training=None):
        # Collapse the time axis: (batch, time, features) -> (batch, features).
        hidden = self.global_pool(x)
        for dense_layer, drop_layer in zip(self.dense_layers, self.dropout_layers):
            hidden = drop_layer(dense_layer(hidden), training=training)
        return self.output_layer(hidden)

class SeizurePredictorCNNTransformer(Model):
    """CNN feature extractor -> Transformer encoder -> binary classifier.

    Inputs have shape ``(batch, n_channels, n_timesteps)``. The output is one
    sigmoid probability per sample, shape ``(batch, 1)``. The transformer
    width ``d_model`` is taken from ``conv_filters[-1]`` so that it matches the
    number of CNN output channels.

    ``n_channels`` and ``n_timesteps`` are only used by ``build_model`` to
    create the dummy batch that initialises the weights.
    """

    def __init__(self, n_channels=22, n_timesteps=2560,
                 conv_filters=(64, 128, 256), conv_kernel_sizes=(7, 5, 3),
                 pool_sizes=(2, 2, 2), cnn_dropout=0.3,
                 n_heads=8, n_transformer_layers=4, d_ff=512,
                 transformer_dropout=0.1, classifier_hidden=(128, 64),
                 classifier_dropout=0.5, **kwargs):
        super().__init__(**kwargs)
        self.n_channels = n_channels
        self.n_timesteps = n_timesteps

        self.cnn_encoder = CNNEncoder(conv_filters, conv_kernel_sizes,
                                      pool_sizes, cnn_dropout)
        d_model = conv_filters[-1]
        self.transformer = TransformerEncoder(n_transformer_layers, d_model,
                                              n_heads, d_ff, transformer_dropout)
        self.classifier = ClassificationHead(classifier_hidden, classifier_dropout)

    def call(self, inputs, training=None):
        x = self.cnn_encoder(inputs, training=training)
        x = self.transformer(x, training=training)
        return self.classifier(x, training=training)

    def build_model(self):
        """Create all weights by running one zero batch through the model.

        The model is invoked through ``__call__`` rather than ``call``. Keras
        only marks the model itself as built on that path, and ``summary()``
        needs a built model. Returns ``self`` for chaining.
        """
        dummy = tf.zeros((1, self.n_channels, self.n_timesteps))
        self(dummy, training=False)
        return self

def create_model(n_channels=22, n_timesteps=2560, **model_kwargs):
    """Instantiate and build a SeizurePredictorCNNTransformer.

    Args:
        n_channels: Number of EEG channels in each input window.
        n_timesteps: Samples per input window.
        **model_kwargs: Optional architecture overrides forwarded to
            SeizurePredictorCNNTransformer (e.g. ``conv_filters``, ``n_heads``,
            ``n_transformer_layers``, ``classifier_dropout``).

    Returns:
        The model, with weights created via ``build_model()``.
    """
    model = SeizurePredictorCNNTransformer(n_channels=n_channels,
                                           n_timesteps=n_timesteps,
                                           **model_kwargs)
    model.build_model()
    return model

In [None]:
%%writefile models/__init__.py
from models.seizure_predictor import SeizurePredictorCNNTransformer, create_model

In [None]:
%%writefile data/preprocessing.py
"""EEG Signal Preprocessing"""
import numpy as np
from scipy import signal

class EEGPreprocessor:
    """Filter, normalize and window multi-channel EEG signals.

    Every operation works along the last axis, which is treated as time; any
    leading axes (e.g. channels) are processed independently.

    Args:
        sampling_rate: Sampling rate of the input in Hz.
        bandpass_low: Lower band-pass cutoff in Hz.
        bandpass_high: Upper band-pass cutoff in Hz.
        notch_freq: Centre frequency of the notch filter in Hz (e.g. the
            power-line frequency). The notch is skipped unless it lies strictly
            between 0 and the Nyquist frequency.
        window_duration: Window length produced by ``segment``, in seconds.
        window_overlap: Fraction of each window shared with the next one.
        notch_width: -3 dB bandwidth of the notch in Hz. The quality factor is
            ``notch_freq / notch_width``; the default (2.0) reproduces the
            former fixed ``notch_freq / 2``.
    """

    def __init__(self, sampling_rate=256, bandpass_low=0.5, bandpass_high=40.0,
                 notch_freq=60.0, window_duration=10.0, window_overlap=0.5,
                 notch_width=2.0):
        if notch_width <= 0:
            raise ValueError(f"notch_width must be positive, got {notch_width}")
        self.sampling_rate = sampling_rate
        self.bandpass_low = bandpass_low
        self.bandpass_high = bandpass_high
        self.notch_freq = notch_freq
        self.window_duration = window_duration
        self.window_overlap = window_overlap
        self.notch_width = notch_width
        self._init_filters()

    def _init_filters(self):
        """Design the band-pass and notch filters for the configured rate."""
        nyquist = self.sampling_rate / 2
        # Normalized cutoffs must lie strictly inside (0, 1); clamp them.
        low = max(0.001, min(self.bandpass_low / nyquist, 0.99))
        high = max(0.001, min(self.bandpass_high / nyquist, 0.99))
        # A 4th-order band-pass is an 8th-order filter. Its (b, a) polynomial
        # form is numerically fragile at low normalized cutoffs such as
        # 0.5 Hz / 128 Hz, so the filtering uses second-order sections.
        # (b, a) is still exposed for callers that read it.
        self.bandpass_sos = signal.butter(4, [low, high], btype='band', output='sos')
        self.bandpass_b, self.bandpass_a = signal.butter(4, [low, high], btype='band')
        notch_norm = self.notch_freq / nyquist
        if 0 < notch_norm < 1:
            quality = self.notch_freq / self.notch_width
            self.notch_b, self.notch_a = signal.iirnotch(notch_norm, quality)
        else:
            self.notch_b, self.notch_a = None, None

    def bandpass_filter(self, data):
        """Apply the zero-phase (forward-backward) band-pass filter along the last axis."""
        return signal.sosfiltfilt(self.bandpass_sos, data, axis=-1)

    def notch_filter(self, data):
        """Apply the zero-phase notch filter, or return ``data`` unchanged if disabled."""
        if self.notch_b is not None:
            return signal.filtfilt(self.notch_b, self.notch_a, data, axis=-1)
        return data

    def normalize(self, data):
        """Z-score each series along the last axis (epsilon avoids division by zero)."""
        mean = np.mean(data, axis=-1, keepdims=True)
        std = np.std(data, axis=-1, keepdims=True) + 1e-8
        return (data - mean) / std

    def preprocess(self, data):
        """Band-pass, then notch, then z-score normalize ``data``."""
        data = self.bandpass_filter(data)
        data = self.notch_filter(data)
        data = self.normalize(data)
        return data

    def segment(self, data):
        """Split ``data`` into fixed-length windows along the last axis.

        Returns:
            Array of shape ``(n_windows, *data.shape[:-1], window_samples)``.
            Trailing samples that do not fill a whole window are dropped. A
            recording shorter than one window yields ``n_windows == 0`` but
            keeps the other dimensions.
        """
        data = np.asarray(data)
        window_samples = int(self.window_duration * self.sampling_rate)
        # window_overlap >= 1 would give a zero or negative stride (range()
        # raises on zero), so always advance at least one sample.
        stride = max(1, int(window_samples * (1 - self.window_overlap)))
        n_samples = data.shape[-1]
        starts = range(0, n_samples - window_samples + 1, stride)
        if len(starts) == 0:
            return np.empty((0, *data.shape[:-1], window_samples), dtype=data.dtype)
        return np.stack([data[..., start:start + window_samples] for start in starts])

In [None]:
%%writefile data/dataset.py
"""Dataset utilities"""
import numpy as np
from sklearn.model_selection import train_test_split

def create_synthetic_dataset(n_samples=1000, n_channels=22, n_timesteps=2560,
                             preictal_ratio=0.3, seed=42):
    """Create synthetic, labelled EEG-like data for testing the pipeline.

    Interictal samples (label 0) are Gaussian noise scaled by 0.5. Preictal
    samples (label 1) are Gaussian noise scaled by 1.5, plus 5-14 exponentially
    decaying 50-sample spikes added at random channels and positions. The data
    are shuffled and split 70/15/15 into train/validation/test.

    Randomness comes from a private ``np.random.RandomState(seed)``, so the
    global NumPy RNG is not reseeded as a side effect. For a given ``seed`` the
    draws are the same as after ``np.random.seed(seed)``.

    Args:
        n_samples: Total number of samples.
        n_channels: Channels per sample.
        n_timesteps: Time steps per sample; must exceed 50 when any preictal
            samples are generated (spike length).
        preictal_ratio: Fraction of samples labelled preictal (truncated).
        seed: Seed for the data generator and the train/test splits.

    Returns:
        ``(X_train, X_val, X_test, y_train, y_val, y_test)``. ``X_*`` are
        float32 arrays of shape ``(n, n_channels, n_timesteps)``; ``y_*`` are
        float32 0/1 labels.
    """
    rng = np.random.RandomState(seed)
    n_preictal = int(n_samples * preictal_ratio)
    n_interictal = n_samples - n_preictal

    # Interictal: low amplitude random noise
    X_interictal = rng.randn(n_interictal, n_channels, n_timesteps) * 0.5

    # Preictal: higher amplitude with spikes
    X_preictal = rng.randn(n_preictal, n_channels, n_timesteps) * 1.5
    spike_shape = np.exp(-np.linspace(0, 3, 50))  # loop-invariant decay envelope
    for i in range(n_preictal):
        n_spikes = rng.randint(5, 15)
        for _ in range(n_spikes):
            pos = rng.randint(0, n_timesteps - 50)
            channel = rng.randint(0, n_channels)
            spike = spike_shape * rng.uniform(3, 8)
            X_preictal[i, channel, pos:pos+50] += spike

    X = np.concatenate([X_interictal, X_preictal], axis=0).astype(np.float32)
    y = np.concatenate([np.zeros(n_interictal), np.ones(n_preictal)]).astype(np.float32)

    # Shuffle
    idx = rng.permutation(n_samples)
    X, y = X[idx], y[idx]

    # Split: 70% train, then the remaining 30% halved into validation/test.
    X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.3, random_state=seed)
    X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=seed)

    return X_train, X_val, X_test, y_train, y_val, y_test

## 3️⃣ Quick Demo with Synthetic Data

Run this section to test the model without downloading the actual dataset:

In [None]:
import numpy as np
import tensorflow as tf
from config import eeg_config, training_config
from data.dataset import create_synthetic_dataset
from models.seizure_predictor import create_model

# Create a synthetic dataset shaped like the real windows. Channel count,
# window length and seed come from config.py, so edits there carry through
# to this demo (the defaults give 22 channels x 2560 samples, seed 42).
print("Creating synthetic dataset...")
X_train, X_val, X_test, y_train, y_val, y_test = create_synthetic_dataset(
    n_samples=500,
    n_channels=eeg_config.n_channels,
    n_timesteps=eeg_config.window_samples,
    preictal_ratio=0.3,
    seed=training_config.random_seed,
)

print(f"Training set: {X_train.shape}, Labels: {y_train.shape}")
print(f"Validation set: {X_val.shape}, Labels: {y_val.shape}")
print(f"Test set: {X_test.shape}, Labels: {y_test.shape}")
print(f"Class distribution - Preictal: {int(y_train.sum())}/{len(y_train)} ({y_train.mean():.1%})")

In [None]:
# Create the model. The input shape is read from the data, so the two cannot
# drift apart: X_train is (samples, channels, timesteps).
print("\nBuilding CNN + Transformer model...")
model = create_model(n_channels=X_train.shape[1], n_timesteps=X_train.shape[2])
model.summary()

In [None]:
# Compile model
from sklearn.utils.class_weight import compute_class_weight

# 'balanced' class weights make the minority class count as much as the
# majority class in the loss. The mapping is built from the classes actually
# present, with int keys and plain float values (the form Keras' class_weight
# expects), instead of assuming labels 0 and 1 both occur.
classes = np.unique(y_train)
class_weights = compute_class_weight('balanced', classes=classes, y=y_train)
class_weight_dict = {int(c): float(w) for c, w in zip(classes, class_weights)}
print(f"Class weights: {class_weight_dict}")

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=training_config.learning_rate),
    loss='binary_crossentropy',
    metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
)

print("✓ Model compiled")

In [None]:
# Train model. Both callbacks watch their default monitor (val_loss): stop
# early and restore the best weights, and halve the LR on plateaus.
early_stopping = tf.keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)
lr_on_plateau = tf.keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=5, min_lr=1e-7)
callbacks = [early_stopping, lr_on_plateau]

print("\nTraining model...")
fit_options = dict(
    validation_data=(X_val, y_val),
    epochs=20,  # reduced for the demo
    batch_size=32,
    class_weight=class_weight_dict,
    callbacks=callbacks,
    verbose=1,
)
history = model.fit(X_train, y_train, **fit_options)

In [None]:
# Evaluate model
print("\n" + "="*50)
print("EVALUATION RESULTS")
print("="*50)

# return_dict=True looks each metric up by name instead of relying on its
# position in the list that evaluate() would otherwise return.
test_metrics = model.evaluate(X_test, y_test, verbose=0, return_dict=True)
print(f"Test Loss: {test_metrics['loss']:.4f}")
print(f"Test Accuracy: {test_metrics['accuracy']:.2%}")
print(f"Test AUC: {test_metrics['auc']:.4f}")

# Predictions: threshold the sigmoid output at 0.5
y_pred_prob = model.predict(X_test, verbose=0).flatten()
y_pred = (y_pred_prob >= 0.5).astype(int)

from sklearn.metrics import classification_report, confusion_matrix

# Explicit labels keep the report and matrix two-class (matching target_names)
# even when one class is missing from y_test or y_pred.
print("\nClassification Report:")
print(classification_report(y_test, y_pred, labels=[0, 1],
                            target_names=['Interictal', 'Preictal']))

print("Confusion Matrix:")
print(confusion_matrix(y_test, y_pred, labels=[0, 1]))

In [None]:
# Plot training history: one panel per metric, with train and validation curves.
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

panels = [('loss', 'Loss'), ('accuracy', 'Accuracy'), ('auc', 'AUC')]
for ax, (metric, title) in zip(axes, panels):
    ax.plot(history.history[metric], label='Train')
    ax.plot(history.history[f'val_{metric}'], label='Validation')
    ax.set_title(title)
    ax.legend()

plt.tight_layout()
plt.show()

## 4️⃣ Save / Load Pretrained Model

In [None]:
# Save the trained model in the native Keras (.keras) format.
MODEL_PATH = 'saved_models/cnn_transformer_colab.keras'
model.save(MODEL_PATH)
print(f"✓ Model saved to {MODEL_PATH}")

# Download the saved file to the local machine (Colab only).
from google.colab import files
files.download(MODEL_PATH)

In [None]:
# Load pretrained model. custom_objects maps each custom class name to its
# class so load_model can rebuild the subclassed layers.
from models.seizure_predictor import SeizurePredictorCNNTransformer, ClassificationHead
from models.cnn_encoder import CNNEncoder
from models.transformer_encoder import TransformerEncoder, PositionalEncoding, TransformerEncoderLayer

custom_objects = {
    cls.__name__: cls
    for cls in (
        SeizurePredictorCNNTransformer,
        CNNEncoder,
        TransformerEncoder,
        ClassificationHead,
        PositionalEncoding,
        TransformerEncoderLayer,
    )
}

# To load a saved model, uncomment:
# loaded_model = tf.keras.models.load_model('saved_models/cnn_transformer_colab.keras',
#                                           custom_objects=custom_objects)
# print("✓ Model loaded successfully")

## 5️⃣ Run Inference on New Data

In [None]:
def predict_seizure(model, eeg_data, threshold=0.5):
    """
    Predict seizure probability from EEG data.

    Args:
        model: Trained model exposing ``predict(x, verbose=0)`` that returns
            one probability per sample (any shape that flattens to ``(batch,)``).
        eeg_data: EEG array (or nested sequence) of shape
            (n_channels, n_samples) or (batch, n_channels, n_samples).
        threshold: Probability at or above which a sample is labelled 'Preictal'.

    Returns:
        Dictionary with:
            'probabilities': 1-D NumPy array of per-sample probabilities.
            'predictions': list of 'Preictal' / 'Interictal' labels.
            'mean_probability', 'max_probability': Python floats.
            'seizure_warning': Python bool; True if any sample reaches ``threshold``.

    Raises:
        ValueError: If ``eeg_data`` contains no samples.
    """
    eeg_data = np.asarray(eeg_data)
    # Add a batch dimension for a single (n_channels, n_samples) recording
    if eeg_data.ndim == 2:
        eeg_data = eeg_data[np.newaxis, ...]
    if eeg_data.shape[0] == 0:
        raise ValueError("eeg_data contains no samples")

    probabilities = np.asarray(model.predict(eeg_data, verbose=0)).flatten()
    max_probability = float(probabilities.max())

    return {
        'probabilities': probabilities,
        'predictions': ['Preictal' if p >= threshold else 'Interictal' for p in probabilities],
        'mean_probability': float(probabilities.mean()),
        'max_probability': max_probability,
        'seizure_warning': max_probability >= threshold,
    }

# Example inference on the first five test windows
print("Running inference on test samples...\n")
sample_data = X_test[:5]
results = predict_seizure(model, sample_data)

for i, (pred, prob) in enumerate(zip(results['predictions'], results['probabilities'])):
    actual = 'Preictal' if y_test[i] == 1 else 'Interictal'
    status = '✓' if pred == actual else '✗'
    print(f"Sample {i+1}: {pred} ({prob:.2%}) - Actual: {actual} {status}")

print(f"\n⚠️ Seizure Warning: {results['seizure_warning']}")

## 6️⃣ Download Real CHB-MIT Dataset (Optional)

Run this section to download actual EEG data from PhysioNet:

In [None]:
%%writefile data/download.py
"""CHB-MIT Dataset Download"""
import os
import re
import urllib.request
from dataclasses import dataclass
from typing import List, Dict, Tuple

# Root URL of the CHB-MIT database (v1.0.0) on PhysioNet. Per-patient files are
# fetched from f"{PHYSIONET_BASE_URL}/{patient_id}/{file_name}".
PHYSIONET_BASE_URL = "https://physionet.org/files/chbmit/1.0.0"

@dataclass
class SeizureInfo:
    """One annotated seizure interval within a recording file.

    Start and end hold the integers parsed from a summary's "Seizure ... Start
    Time" / "End Time" lines (presumably seconds from file start -- confirm
    against the dataset documentation).
    """

    file_name: str   # recording that contains the seizure
    start_time: int  # seizure onset
    end_time: int    # seizure offset

@dataclass
class PatientInfo:
    """Files and seizures recorded for one patient."""

    patient_id: str                # e.g. 'chb01'
    files: List[str]               # recording file names
    seizures: List[SeizureInfo]    # seizures occurring in those files

def download_file(url, destination, verbose=True):
    """Download ``url`` to ``destination`` unless that file already exists.

    The data is first written to ``destination + '.part'`` and moved into
    place only after the transfer completes. An interrupted download therefore
    never leaves a truncated file that a later call would treat as finished.

    Args:
        url: Source URL (anything ``urllib.request.urlretrieve`` accepts).
        destination: Target file path; missing parent directories are created.
        verbose: Print progress messages.

    Returns:
        True if the file already existed or was downloaded. False if the
        download failed; the error is printed, not raised.
    """
    name = os.path.basename(destination)
    if os.path.exists(destination):
        if verbose:
            print(f"  Exists: {name}")
        return True

    partial = destination + '.part'
    try:
        parent = os.path.dirname(destination)
        if parent:  # a bare file name has no directory to create
            os.makedirs(parent, exist_ok=True)
        if verbose:
            print(f"  Downloading: {name}...")
        urllib.request.urlretrieve(url, partial)
        os.replace(partial, destination)
    except Exception as e:  # best effort: report and let the caller skip this file
        print(f"  ✗ Failed: {e}")
        try:
            os.remove(partial)
        except OSError:
            pass  # nothing was written (or it cannot be removed); failure is already reported
        return False

    if verbose:
        print(f"  ✓ Downloaded: {name}")
    return True

def parse_seizure_summary(summary_path):
    """Parse a CHB-MIT ``*-summary.txt`` file.

    The text is split on ``File Name:`` markers, and the first line after each
    marker is taken as the recording name. Chunks that do not contain
    ``Number of Seizures in File: 0`` are scanned for seizure lines. Each
    ``Seizure ... Start Time: N`` is paired, in order, with the matching
    ``Seizure ... End Time: N``.

    Returns:
        ``(seizures, files)``: a list of SeizureInfo and the list of every
        recording name, both in file order.
    """
    with open(summary_path, 'r') as handle:
        text = handle.read()

    seizures = []
    files = []
    # Everything before the first "File Name:" is header text: skip chunk 0.
    for chunk in re.split(r'File Name:', text)[1:]:
        name = chunk.strip().split('\n')[0].strip()
        files.append(name)
        if 'Number of Seizures in File: 0' in chunk:
            continue
        onsets = re.findall(r'Seizure.*Start Time:\s*(\d+)', chunk)
        offsets = re.findall(r'Seizure.*End Time:\s*(\d+)', chunk)
        seizures.extend(
            SeizureInfo(name, int(onset), int(offset))
            for onset, offset in zip(onsets, offsets)
        )
    return seizures, files

def download_chb_mit_sample(patient_ids=None, data_dir='data/raw', max_files=5, verbose=True):
    """Download a small subset of CHB-MIT recordings for the given patients.

    For each patient the summary file is fetched and parsed. Recordings with
    at least one annotated seizure are queued ahead of seizure-free ones, and
    at most ``max_files`` recordings are downloaded into
    ``data_dir/<patient_id>/``. Patients whose summary cannot be downloaded
    are skipped.

    Returns:
        Dict mapping patient id to a PatientInfo that lists the files actually
        downloaded and the seizures they contain.
    """
    if patient_ids is None:
        patient_ids = ['chb01', 'chb02', 'chb03']

    patients_info = {}
    for pid in patient_ids:
        if verbose:
            print(f"\nDownloading {pid}...")
        patient_dir = os.path.join(data_dir, pid)
        os.makedirs(patient_dir, exist_ok=True)

        summary_path = os.path.join(patient_dir, f"{pid}-summary.txt")
        if not download_file(f"{PHYSIONET_BASE_URL}/{pid}/{pid}-summary.txt", summary_path, verbose):
            continue

        seizures, file_list = parse_seizure_summary(summary_path)
        seizure_files = {s.file_name for s in seizures}
        # Stable sort: seizure recordings first, original order kept within each group.
        queue = sorted(file_list, key=lambda f: f not in seizure_files)

        downloaded = []
        sz_downloaded = []
        for fname in queue[:max_files]:
            local_path = os.path.join(patient_dir, fname)
            if not download_file(f"{PHYSIONET_BASE_URL}/{pid}/{fname}", local_path, verbose):
                continue
            downloaded.append(fname)
            sz_downloaded += [s for s in seizures if s.file_name == fname]

        patients_info[pid] = PatientInfo(pid, downloaded, sz_downloaded)
        if verbose:
            print(f"  {len(downloaded)} files, {len(sz_downloaded)} seizures")
    return patients_info

In [None]:
# Optional: fetch real CHB-MIT recordings from PhysioNet.
from data.download import download_chb_mit_sample

# Uncomment to download 3 patients with up to 5 recordings each (~500MB):
# patients = download_chb_mit_sample(
#     patient_ids=['chb01', 'chb02', 'chb03'],
#     max_files=5,
#     verbose=True,
# )

print("Uncomment the code above to download real EEG data")

---

## 📝 Summary

This notebook provides:

1. **Quick Demo** - Test with synthetic data (no downloads needed)
2. **Model Training** - Train CNN+Transformer on GPU
3. **Save/Load** - Export and import trained models
4. **Inference** - Make predictions on new EEG data
5. **Real Data** - Download CHB-MIT dataset from PhysioNet

### Next Steps:
- Upload your pretrained `.keras` files to load them
- Download more patients for better model performance
- Adjust hyperparameters in the config section

---
🧠 **EEG Seizure Prediction System** | CNN + Transformer Architecture