# 🧠 CogniVue: TPU Training (Compatible with Preprocessing)

**✅ FULLY COMPATIBLE with your preprocessing notebook!**

This notebook:
- ✅ Loads data from a single `.pkl` file per split
- ✅ Handles a variable channel count (typically 58 channels)
- ✅ Correctly transposes X from `(n_ch, 256)` to `(256, n_ch)`
- ✅ Uses TPU v5e-8 for fast training
- ✅ Robust checkpointing and error handling
- ✅ Saves periodic checkpoints and an `interrupted.keras` snapshot on manual interruption, from which training can be resumed manually

## 📋 Setup Instructions

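1. **Upload the preprocessed data** as a Kaggle dataset, keeping the `data/processed/<split>/` folder layout (the loader reads `/kaggle/input/<DATASET_NAME>/data/processed/<split>/<split>_data.pkl`):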
   - Your `data/processed/train/train_data.pkl`
   - Your `data/processed/val/val_data.pkl`  
   - Your `data/processed/test/test_data.pkl`

2. **Set accelerator** to **TPU VM v5e-8**

3. **Attach the dataset** and set `DATASET_NAME` in the Configuration cell below to the dataset's name

4. **Run All** and let it train!

## 📦 1. Install Dependencies

In [None]:
# IPython shell escape: runs pip quietly (-q) inside the notebook kernel's
# environment before any imports below.
# NOTE(review): pandas is not imported by any cell in this notebook.
!pip install -q tensorflow pandas numpy

## ⚙️ 2. Configuration

In [None]:
import os
import sys
import json
import time
import shutil
import numpy as np
import tensorflow as tf
import pickle
from datetime import datetime

print(f"TensorFlow version: {tf.__version__}")
print(f"Python version: {sys.version}")

# =====================================================
# PATHS CONFIGURATION
# =====================================================

# UPDATE THIS to match the name of the Kaggle dataset attached to the notebook!
DATASET_NAME = "preprocessed-cog-eeg-dataset"

# Read side: the per-split folders (train/, val/, test/) live under this root.
DATA_INPUT_DIR = "/kaggle/input/" + DATASET_NAME + "/data/processed"

# Write side: every artifact produced by the run goes under WORKING_DIR.
WORKING_DIR = "/kaggle/working"
CHECKPOINT_DIR = os.path.join(WORKING_DIR, "checkpoints")
RESULTS_DIR = os.path.join(WORKING_DIR, "results")
LOGS_DIR = os.path.join(WORKING_DIR, "logs")

# exist_ok=True keeps re-running this cell harmless.
for output_dir in (CHECKPOINT_DIR, RESULTS_DIR, LOGS_DIR):
    os.makedirs(output_dir, exist_ok=True)

print("\nüìÇ Paths configured:")
for label, location in (
    ("Input", DATA_INPUT_DIR),
    ("Checkpoints", CHECKPOINT_DIR),
    ("Results", RESULTS_DIR),
):
    print(f"  {label}: {location}")

# =====================================================
# DATA CONSTANTS (from preprocessing)
# =====================================================

# Time samples per window. create_model uses this both as the EEG input length
# and as the positional-embedding table size, so it must match the data.
WINDOW_SIZE_SAMPLES = 256
# NOTE(review): NUM_BANDS is not referenced by the training code in this
# notebook; the bandpower input width is read from the loaded data instead.
NUM_BANDS = 5  # delta, theta, alpha, beta, gamma
# Vocabulary size of the task embedding: every task_idx must be in [0, NUM_TASKS).
NUM_TASKS = 4  # N-back, MATB-II, PVT, Flanker

# Output classes: sizes of the fixed region / band / state heads. The labels
# are used as sparse class indices, so they must lie in [0, N).
NUM_OUTPUT_REGIONS = 7
NUM_OUTPUT_BANDS = 5
NUM_OUTPUT_STATES = 4

# Note: NUM_CHANNELS and NUM_OUTPUT_CHANNELS will be determined from data!
# (train_cognivue derives them from the loaded metadata and y_channel labels.)

# =====================================================
# MODEL HYPERPARAMETERS
# =====================================================

# Transformer width; each attention head uses a key size of D_MODEL // NUM_HEADS.
D_MODEL = 256
NUM_LAYERS = 6  # number of stacked attention + feed-forward blocks
NUM_HEADS = 8
FF_DIM = 1024  # hidden width of each block's feed-forward sub-layer
DROPOUT = 0.15  # used for attention dropout and after each feed-forward sub-layer

# Widths of the two Dense layers in the bandpower stream.
BANDPOWER_HIDDEN_DIM = 128
BANDPOWER_OUTPUT_DIM = 128
TASK_EMBEDDING_DIM = 16  # size of the learned task-index embedding

# =====================================================
# TRAINING HYPERPARAMETERS
# =====================================================

EPOCHS = 100
INITIAL_LR = 1e-4  # peak LR: reached at the end of warmup, then cosine-decayed
WARMUP_EPOCHS = 5  # converted to optimizer steps via steps_per_epoch
WEIGHT_DECAY = 0.01  # passed to AdamW as weight_decay
GRADIENT_CLIP_NORM = 1.0  # passed to AdamW as clipnorm

SAVE_CHECKPOINT_EVERY = 5  # epochs between PeriodicCheckpoint snapshots
EARLY_STOPPING_PATIENCE = 15  # epochs without val_loss improvement before stopping

print(f"\nüîß Configuration:")
print(f"  Model: {NUM_LAYERS} layers, {NUM_HEADS} heads, D_MODEL={D_MODEL}")
print(f"  Training: {EPOCHS} epochs, LR={INITIAL_LR}")
print(f"  Checkpointing: every {SAVE_CHECKPOINT_EVERY} epochs")

## 🔌 3. TPU Initialization

In [None]:
# Prefer the local TPU (TPU VM); on any failure fall back to the default
# distribution strategy (CPU/GPU) with a smaller per-replica batch.
try:
    print("üîç Detecting TPU...")
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)

    print("\n‚úÖ TPU initialized!")
    print(f"  Address: {resolver.master()}")
    print(f"  Replicas: {strategy.num_replicas_in_sync}")

    BATCH_SIZE_PER_REPLICA = 64
    print(f"  Global batch size: {BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync}")
except Exception as e:
    print(f"\n‚ö†Ô∏è TPU init failed: {e}")
    print("   Falling back to CPU/GPU (slower)")

    strategy = tf.distribute.get_strategy()
    BATCH_SIZE_PER_REPLICA = 32
    print(f"  Using: {type(strategy).__name__}")
    print(f"  Batch size: {BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync}")

# Global batch = per-replica batch times the number of synchronized replicas.
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

## 📊 4. Data Loading (Compatible with Preprocessing)

In [None]:
def load_preprocessed_data(split='train', data_dir=None):
    """
    Load one preprocessed split (matches the preprocessing output format).

    Reads ``<data_dir>/<split>/<split>_data.pkl``, which must contain a
    sequence of dicts with keys 'X', 'bp', 'task_idx', 'y_channel',
    'y_region', 'y_band' and 'y_state'. Each 'X' is expected to be
    (channels, time) and each 'bp' (channels, bands).

    SECURITY: pickle.load can execute arbitrary code embedded in the file.
    Only load pickles you produced yourself (e.g. with your preprocessing
    notebook).

    Args:
        split: Split name such as 'train', 'val' or 'test'. It selects both
            the sub-directory and the file name.
        data_dir: Root directory containing the split folders. Defaults to
            the module-level DATA_INPUT_DIR, looked up at call time.

    Returns:
        Tuple of (X, y, metadata) where:
        X = (X_eeg, X_bp, X_task):
            X_eeg  float32 (N, time, channels) - each sample's 'X' transposed
            X_bp   float32 (N, channels * bands) - each sample's 'bp' flattened
            X_task int32 (N,)
        y = (y_channel, y_region, y_band, y_state), each int32 (N,)
        metadata = dict with 'num_channels', 'num_samples', 'bandpower_dim'
        None (after printing the reason) if the file is missing or empty,
        samples have differing shapes, or reading/conversion fails.
    """
    if data_dir is None:
        data_dir = DATA_INPUT_DIR
    pkl_path = os.path.join(data_dir, split, f"{split}_data.pkl")
    
    print(f"\nüìÅ Loading {split} data from: {pkl_path}")
    
    if not os.path.exists(pkl_path):
        print(f"  ‚ùå File not found!")
        print(f"  üí° Check that dataset is attached and DATASET_NAME is correct")
        return None
    
    try:
        with open(pkl_path, 'rb') as f:
            samples = pickle.load(f)  # trusted, self-produced file only (see docstring)
        
        print(f"  ‚úÖ Loaded {len(samples):,} samples")
        
        if len(samples) == 0:
            print(f"  ‚ùå No samples in file!")
            return None
        
        # Inspect first sample to get dimensions
        sample = samples[0]
        X_shape = sample['X'].shape  # Should be (n_channels, 256)
        bp_shape = sample['bp'].shape  # Should be (n_channels, 5)
        
        num_channels = X_shape[0]
        
        print(f"\n  üìä Data format:")
        print(f"     X shape: {X_shape} (channels, time)")
        print(f"     bp shape: {bp_shape} (channels, bands)")
        print(f"     Channels: {num_channels}")
        
        # Every sample must match the first one; otherwise np.array below
        # fails with an opaque ValueError that does not say which sample is bad.
        for i, s in enumerate(samples):
            if s['X'].shape != X_shape or s['bp'].shape != bp_shape:
                print(f"  [ERROR] Sample {i} has X shape {s['X'].shape} and bp shape "
                      f"{s['bp'].shape}; expected {X_shape} and {bp_shape}")
                return None
        
        # Extract arrays
        print(f"\n  üîÑ Converting to arrays...")
        
        # X: Transpose from (n_channels, 256) to (256, n_channels)
        X_eeg = np.array([s['X'].T for s in samples], dtype=np.float32)
        
        # bp: Flatten from (n_channels, 5) to (n_channels*5,)
        X_bp = np.array([s['bp'].flatten() for s in samples], dtype=np.float32)
        
        # task_idx
        X_task = np.array([s['task_idx'] for s in samples], dtype=np.int32)
        
        # Labels
        y_channel = np.array([s['y_channel'] for s in samples], dtype=np.int32)
        y_region = np.array([s['y_region'] for s in samples], dtype=np.int32)
        y_band = np.array([s['y_band'] for s in samples], dtype=np.int32)
        y_state = np.array([s['y_state'] for s in samples], dtype=np.int32)
        
        print(f"  ‚úÖ Final shapes:")
        print(f"     X_eeg: {X_eeg.shape} (N, time, channels)")
        print(f"     X_bp: {X_bp.shape} (N, features)")
        print(f"     X_task: {X_task.shape}")
        print(f"     Labels: {y_channel.shape} each")
        
        metadata = {
            'num_channels': num_channels,
            'num_samples': len(samples),
            'bandpower_dim': X_bp.shape[1]
        }
        
        return (X_eeg, X_bp, X_task), (y_channel, y_region, y_band, y_state), metadata
        
    except Exception as e:
        # Top-level boundary for this loader: report and signal failure with None.
        print(f"  ‚ùå Error loading data: {e}")
        import traceback
        traceback.print_exc()
        return None

print("‚úÖ Data loading function defined")

## 🏗️ 5. Model Architecture (Flexible Channels)

In [None]:
class PositionalEmbedding(tf.keras.layers.Layer):
    """
    Add a learned positional embedding to a (batch, sequence_length, d_model) input.

    Wrapping the Embedding in a layer makes its table part of the model graph,
    so it is trained, counted by count_params() and saved with the model.
    NOTE(review): loading a saved model in a fresh session may require
    custom_objects={'PositionalEmbedding': PositionalEmbedding}.
    """

    def __init__(self, sequence_length, d_model, **kwargs):
        super().__init__(**kwargs)
        self.sequence_length = sequence_length
        self.d_model = d_model
        self.position_embedding = tf.keras.layers.Embedding(
            input_dim=sequence_length,
            output_dim=d_model,
        )

    def call(self, inputs):
        positions = tf.range(start=0, limit=self.sequence_length, delta=1)
        # (sequence_length, d_model) broadcasts across the batch axis.
        return inputs + self.position_embedding(positions)

    def get_config(self):
        config = super().get_config()
        config.update({
            'sequence_length': self.sequence_length,
            'd_model': self.d_model,
        })
        return config


def create_model(num_channels, bandpower_input_dim, num_output_channels):
    """
    Create the EEG Transformer with flexible channel dimensions.

    Three streams are fused by concatenation:
      * EEG: Dense projection to D_MODEL + learned positional embedding,
        NUM_LAYERS post-norm Transformer blocks, global average pooling.
      * Bandpower: two ReLU Dense layers.
      * Task: embedding of the integer task index.
    The fused vector feeds four linear heads that output logits (no softmax).

    Args:
        num_channels: Number of EEG channels (e.g., 58).
        bandpower_input_dim: Bandpower feature dimension (num_channels * 5).
        num_output_channels: Number of output classes for channel prediction.

    Returns:
        A tf.keras.Model with inputs [eeg, bp, task] and a dict of logits keyed
        'channel', 'region', 'band', 'state'.
    """
    # Inputs
    eeg_input = tf.keras.Input(shape=(WINDOW_SIZE_SAMPLES, num_channels), name='eeg')
    bp_input = tf.keras.Input(shape=(bandpower_input_dim,), name='bp')
    task_input = tf.keras.Input(shape=(1,), dtype='int32', name='task')
    
    # ==================== EEG STREAM ====================
    x = tf.keras.layers.Dense(D_MODEL, name='eeg_projection')(eeg_input)
    
    # Learned positional encoding. It must be a layer in the graph: calling an
    # Embedding eagerly on a constant tf.range (as before) produced a fixed
    # random tensor whose weights were never part of the model or trained.
    x = PositionalEmbedding(
        WINDOW_SIZE_SAMPLES, D_MODEL, name='positional_embedding'
    )(x)
    
    # Transformer layers (post-norm: residual add, then LayerNormalization)
    for i in range(NUM_LAYERS):
        attn = tf.keras.layers.MultiHeadAttention(
            num_heads=NUM_HEADS,
            key_dim=D_MODEL // NUM_HEADS,
            dropout=DROPOUT,
            name=f'mha_{i}'
        )(x, x)
        x = tf.keras.layers.Add(name=f'add_attn_{i}')([x, attn])
        x = tf.keras.layers.LayerNormalization(epsilon=1e-6, name=f'ln_attn_{i}')(x)
        
        ffn = tf.keras.Sequential([
            tf.keras.layers.Dense(FF_DIM, activation='relu'),
            tf.keras.layers.Dense(D_MODEL),
            tf.keras.layers.Dropout(DROPOUT)
        ], name=f'ffn_{i}')
        
        ffn_out = ffn(x)
        x = tf.keras.layers.Add(name=f'add_ffn_{i}')([x, ffn_out])
        x = tf.keras.layers.LayerNormalization(epsilon=1e-6, name=f'ln_ffn_{i}')(x)
    
    eeg_emb = tf.keras.layers.GlobalAveragePooling1D(name='eeg_pool')(x)
    
    # ==================== BANDPOWER STREAM ====================
    bp_x = tf.keras.layers.Dense(BANDPOWER_HIDDEN_DIM, activation='relu', name='bp_hidden')(bp_input)
    bp_emb = tf.keras.layers.Dense(BANDPOWER_OUTPUT_DIM, activation='relu', name='bp_output')(bp_x)
    
    # ==================== TASK STREAM ====================
    task_emb = tf.keras.layers.Embedding(NUM_TASKS, TASK_EMBEDDING_DIM, name='task_emb')(task_input)
    task_emb = tf.keras.layers.Flatten(name='task_flatten')(task_emb)
    
    # ==================== FUSION ====================
    fused = tf.keras.layers.Concatenate(name='fusion')([eeg_emb, bp_emb, task_emb])
    
    # ==================== MULTI-TASK HEADS ====================
    # Linear heads: logits, matching from_logits=True in the loss.
    out_channel = tf.keras.layers.Dense(num_output_channels, name='channel')(fused)
    out_region = tf.keras.layers.Dense(NUM_OUTPUT_REGIONS, name='region')(fused)
    out_band = tf.keras.layers.Dense(NUM_OUTPUT_BANDS, name='band')(fused)
    out_state = tf.keras.layers.Dense(NUM_OUTPUT_STATES, name='state')(fused)
    
    model = tf.keras.Model(
        inputs=[eeg_input, bp_input, task_input],
        outputs={
            'channel': out_channel,
            'region': out_region,
            'band': out_band,
            'state': out_state
        },
        name='CogniVue_Transformer'
    )
    
    return model

print("‚úÖ Model architecture defined")

## 📈 6. Learning Rate Schedule

In [None]:
class WarmupCosineDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    """
    Linear warmup to `initial_learning_rate`, then cosine decay to 0.

        step < warmup_steps:  lr = initial_lr * step / warmup_steps
        otherwise:            lr = initial_lr * 0.5 * (1 + cos(pi * p)),
                              p = clip((step - warmup_steps) / (total_steps - warmup_steps), 0, 1)

    Progress p is clamped to [0, 1], so the LR stays at 0 after total_steps
    instead of climbing back up the cosine. Zero-length warmup or decay
    phases are guarded, so they cannot produce inf/NaN.

    Args:
        initial_learning_rate: Peak learning rate reached at the end of warmup.
        warmup_steps: Number of optimizer steps of linear warmup.
        total_steps: Total optimizer steps (warmup + decay).
    """

    def __init__(self, initial_learning_rate, warmup_steps, total_steps):
        super().__init__()
        self.initial_learning_rate = initial_learning_rate
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
    
    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        warmup_steps = tf.cast(self.warmup_steps, tf.float32)
        total_steps = tf.cast(self.total_steps, tf.float32)
        peak_lr = tf.cast(self.initial_learning_rate, tf.float32)
        
        # max(.., 1) avoids 0/0 when warmup_steps == 0; in that case this
        # branch is never selected anyway (step < 0 is never true).
        warmup_lr = peak_lr * step / tf.maximum(warmup_steps, 1.0)
        
        # Guard against total_steps <= warmup_steps, then clamp progress so a
        # run longer than total_steps keeps LR at 0 rather than re-increasing it.
        decay_steps = tf.maximum(total_steps - warmup_steps, 1.0)
        progress = tf.clip_by_value((step - warmup_steps) / decay_steps, 0.0, 1.0)
        decay_lr = peak_lr * 0.5 * (1.0 + tf.cos(np.pi * progress))
        
        return tf.where(step < warmup_steps, warmup_lr, decay_lr)
    
    def get_config(self):
        return {
            "initial_learning_rate": self.initial_learning_rate,
            "warmup_steps": self.warmup_steps,
            "total_steps": self.total_steps,
        }

print("‚úÖ LR schedule defined")

## 🔄 7. Data Pipeline & Callbacks

In [None]:
def create_tf_dataset(X, y, is_train=True):
    """
    Build a batched, prefetched tf.data pipeline keyed by model input/output names.

    Args:
        X: (eeg, bandpower, task) arrays, mapped to inputs 'eeg', 'bp', 'task'.
        y: (channel, region, band, state) label arrays, mapped to the heads of
            the same names.
        is_train: When True, shuffle with a 10,000-element buffer.

    Batches use the global BATCH_SIZE with drop_remainder=True, so every
    batch has a static size and a trailing partial batch is discarded.
    """
    features = {'eeg': X[0], 'bp': X[1], 'task': X[2]}
    targets = {'channel': y[0], 'region': y[1], 'band': y[2], 'state': y[3]}
    pipeline = tf.data.Dataset.from_tensor_slices((features, targets))

    if not is_train:
        return pipeline.batch(BATCH_SIZE, drop_remainder=True).prefetch(tf.data.AUTOTUNE)

    return (
        pipeline.shuffle(10000)
        .batch(BATCH_SIZE, drop_remainder=True)
        .prefetch(tf.data.AUTOTUNE)
    )


class PeriodicCheckpoint(tf.keras.callbacks.Callback):
    """
    Save the full model every `save_freq` epochs.

    Files are written to CHECKPOINT_DIR as checkpoint_epoch_NNN.keras, where
    NNN is the 1-based, zero-padded epoch number.
    """

    def __init__(self, save_freq=5):
        super().__init__()
        self.save_freq = save_freq
    
    def on_epoch_end(self, epoch, logs=None):
        completed_epochs = epoch + 1  # Keras passes a 0-based epoch index
        if completed_epochs % self.save_freq:
            return
        target_path = os.path.join(
            CHECKPOINT_DIR, f"checkpoint_epoch_{completed_epochs:03d}.keras"
        )
        self.model.save(target_path)
        print(f"\n  üíæ Saved checkpoint: {os.path.basename(target_path)}")


class LearningRateLogger(tf.keras.callbacks.Callback):
    """
    Report the optimizer's current learning rate after every epoch.

    The value is printed and, when a logs dict is supplied, stored as
    logs['learning_rate'] for callbacks that run after this one.
    """

    def on_epoch_end(self, epoch, logs=None):
        optimizer = self.model.optimizer
        schedule_or_value = optimizer.learning_rate
        # A schedule object is callable and must be evaluated at the current
        # step; anything else is already the learning-rate value.
        current = (
            schedule_or_value(optimizer.iterations)
            if hasattr(schedule_or_value, '__call__')
            else schedule_or_value
        )
        lr_float = float(tf.keras.backend.get_value(current))
        if logs is not None:
            logs['learning_rate'] = lr_float
        print(f"\n  üìä LR = {lr_float:.6f}")


class ProgressLogger(tf.keras.callbacks.Callback):
    """
    Print an epoch banner, the epoch's wall-clock time, and headline metrics.

    Reports loss/val_loss and region_accuracy/val_region_accuracy, showing 0
    for any metric missing from the logs.
    """

    def __init__(self):
        super().__init__()
        self.epoch_start = None  # set by on_epoch_begin
    
    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_start = time.time()
        rule = '=' * 70
        print(f"\n{rule}")
        print(f"üìÖ Epoch {epoch+1}/{EPOCHS}")
        print(rule)
    
    def on_epoch_end(self, epoch, logs=None):
        seconds = time.time() - self.epoch_start
        print(f"\n‚úÖ Epoch {epoch+1} done in {seconds:.1f}s")
        if logs:
            metric = logs.get
            print(f"   Loss: {metric('loss', 0):.4f} | Val Loss: {metric('val_loss', 0):.4f}")
            print(f"   Region Acc: {metric('region_accuracy', 0):.4f} | Val: {metric('val_region_accuracy', 0):.4f}")
        print('=' * 70 + "\n")

print("‚úÖ Pipeline & callbacks defined")

## 🚀 8. Main Training Function

In [None]:
def train_cognivue():
    """
    Run the full pipeline: load the splits, sanity-check them, build, compile and fit.

    Uses the module-level configuration (strategy, BATCH_SIZE, EPOCHS,
    WARMUP_EPOCHS, INITIAL_LR, ...). Writes checkpoints to CHECKPOINT_DIR and
    TensorBoard logs to LOGS_DIR.

    Returns:
        The History from model.fit on success. Otherwise None, which covers:
        data loading fails; train/val dimensions disagree; labels are out of
        range; a split is smaller than one global batch; training is
        interrupted with KeyboardInterrupt (interrupted.keras is saved first);
        or model.fit raises.
    """
    print("\n" + "="*70)
    print("üß† CogniVue Training Pipeline")
    print("="*70)
    
    # Load data
    print("\nüìä Loading data...")
    train_result = load_preprocessed_data('train')
    val_result = load_preprocessed_data('val')
    
    if train_result is None or val_result is None:
        print("\n‚ùå Data loading failed!")
        return None
    
    train_data, train_labels, train_meta = train_result
    val_data, val_labels, val_meta = val_result
    
    # The model is shaped from the training split, so the validation split must
    # agree; otherwise the mismatch only surfaces deep inside model.fit.
    for key in ('num_channels', 'bandpower_dim'):
        if train_meta[key] != val_meta[key]:
            print(f"\n[ERROR] train/val {key} mismatch: "
                  f"{train_meta[key]} (train) vs {val_meta[key]} (val)")
            return None
    
    # Get dimensions from data
    num_channels = int(train_meta['num_channels'])
    bandpower_input_dim = int(train_meta['bandpower_dim'])
    
    # Determine the channel-head size from the labels. .max() returns a numpy
    # scalar; cast to a plain int before it becomes a Dense `units` value.
    num_output_channels = int(max(train_labels[0].max(), val_labels[0].max())) + 1
    
    # Sparse cross-entropy requires labels in [0, num_classes). Out-of-range
    # labels are not reliably reported as errors on accelerators (they can
    # surface as NaN losses), so check them up front.
    head_classes = {
        'channel': (0, num_output_channels),
        'region': (1, NUM_OUTPUT_REGIONS),
        'band': (2, NUM_OUTPUT_BANDS),
        'state': (3, NUM_OUTPUT_STATES),
    }
    for head, (label_idx, n_classes) in head_classes.items():
        for split_name, labels in (('train', train_labels), ('val', val_labels)):
            lo, hi = int(labels[label_idx].min()), int(labels[label_idx].max())
            if lo < 0 or hi >= n_classes:
                print(f"\n[ERROR] {split_name} '{head}' labels span [{lo}, {hi}], "
                      f"expected [0, {n_classes - 1}]")
                return None
    
    print(f"\nüìê Model dimensions:")
    print(f"   Input channels: {num_channels}")
    print(f"   Bandpower dim: {bandpower_input_dim}")
    print(f"   Output channels: {num_output_channels}")
    
    steps_per_epoch = train_meta['num_samples'] // BATCH_SIZE
    # create_tf_dataset batches with drop_remainder=True, so a split smaller
    # than one global batch yields no batches at all. For training, that would
    # also give a zero-length LR schedule. Fail early with a clear message.
    if steps_per_epoch == 0 or val_meta['num_samples'] < BATCH_SIZE:
        print(f"\n[ERROR] Each split needs at least BATCH_SIZE={BATCH_SIZE} samples "
              f"(train={train_meta['num_samples']}, val={val_meta['num_samples']})")
        return None
    total_steps = steps_per_epoch * EPOCHS
    warmup_steps = steps_per_epoch * WARMUP_EPOCHS
    
    # Create datasets
    print(f"\nüîÑ Creating TF datasets...")
    train_ds = create_tf_dataset(train_data, train_labels, is_train=True)
    val_ds = create_tf_dataset(val_data, val_labels, is_train=False)
    
    print(f"   Steps/epoch: {steps_per_epoch:,}")
    print(f"   Total steps: {total_steps:,}")
    
    # Build model
    print(f"\nüèóÔ∏è Building model...")
    
    # Model, optimizer and compile must all happen inside the strategy scope so
    # variables are created as distributed variables.
    with strategy.scope():
        model = create_model(num_channels, bandpower_input_dim, num_output_channels)
        
        print(f"   Parameters: {model.count_params():,}")
        
        lr_schedule = WarmupCosineDecay(INITIAL_LR, warmup_steps, total_steps)
        
        optimizer = tf.keras.optimizers.AdamW(
            learning_rate=lr_schedule,
            weight_decay=WEIGHT_DECAY,
            clipnorm=GRADIENT_CLIP_NORM
        )
        
        loss_weights = {'channel': 0.4, 'region': 0.4, 'band': 0.1, 'state': 0.1}
        
        model.compile(
            optimizer=optimizer,
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            loss_weights=loss_weights,
            metrics=['accuracy']
        )
    
    print(f"   ‚úÖ Compiled")
    
    # Callbacks
    callbacks = [
        ProgressLogger(),
        LearningRateLogger(),
        tf.keras.callbacks.ModelCheckpoint(
            filepath=os.path.join(CHECKPOINT_DIR, 'best_model.keras'),
            monitor='val_loss',
            save_best_only=True,
            verbose=1
        ),
        PeriodicCheckpoint(save_freq=SAVE_CHECKPOINT_EVERY),
        tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=EARLY_STOPPING_PATIENCE,
            restore_best_weights=True,
            verbose=1
        ),
        tf.keras.callbacks.TensorBoard(log_dir=LOGS_DIR)
    ]
    
    # Train
    print(f"\nüöÄ Starting training...")
    print(f"‚è∞ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
    
    try:
        history = model.fit(
            train_ds,
            epochs=EPOCHS,
            validation_data=val_ds,
            callbacks=callbacks,
            verbose=1
        )
        
        print("\n" + "="*70)
        print("‚úÖ TRAINING COMPLETE!")
        print("="*70)
        print(f"\nüíæ Saved to: {CHECKPOINT_DIR}")
        
        # Save final
        model.save(os.path.join(CHECKPOINT_DIR, 'final_model.keras'))
        print(f"   - final_model.keras")
        print(f"   - best_model.keras")
        
        return history
        
    except KeyboardInterrupt:
        # Keep whatever has been learned so far.
        print("\n‚ö†Ô∏è Training interrupted!")
        model.save(os.path.join(CHECKPOINT_DIR, 'interrupted.keras'))
        print(f"   üíæ Saved: interrupted.keras")
        return None
        
    except Exception as e:
        # Top-level boundary: report the failure and signal it with None.
        print(f"\n‚ùå Training failed: {e}")
        import traceback
        traceback.print_exc()
        return None

print("‚úÖ Training function ready")

## ▶️ 9. Run Training

In [None]:
history = train_cognivue()

# train_cognivue returns None on every failure path and a History otherwise.
if history is None:
    print("\n‚ö†Ô∏è Training did not complete successfully")
else:
    per_epoch = history.history  # metric name -> list of per-epoch values
    print("\nüìä Training Summary:")
    print(f"   Best val_loss: {min(per_epoch['val_loss']):.4f}")
    print(f"   Final region acc: {per_epoch['region_accuracy'][-1]:.4f}")
    print(f"\nüí° Next: Click 'Save Version' to commit checkpoints!")