# 🧠 CogniVue: TPU Training (Compatible with Preprocessing)

**✅ FULLY COMPATIBLE with your preprocessing notebook!**

This notebook:
- ✅ Loads data from a single `.pkl` file per split
- ✅ Handles a variable channel count (typically 58 channels)
- ✅ Correctly transposes X from `(n_ch, 256)` to `(256, n_ch)`
- ✅ Uses TPU v5e-8 for fast training
- ✅ Robust checkpointing and error handling
- ✅ Resume training from interruptions

## 📋 Setup Instructions

1. **Upload preprocessed data** as Kaggle dataset:
   - Your `data/processed/train/train_data.pkl`
   - Your `data/processed/val/val_data.pkl`  
   - Your `data/processed/test/test_data.pkl`

2. **Set accelerator** to **TPU VM v5e-8**

3. **Attach dataset** and update `DATASET_NAME` below

4. **Run All** and let it train!

## 📦 1. Install Dependencies

## ⚙️ 2. Configuration

In [None]:
import os
import sys
import json
import time
import shutil
import numpy as np
import tensorflow as tf
import pickle
from datetime import datetime

# Report the runtime versions first so the logs identify the environment.
print(f"TensorFlow version: {tf.__version__}")
print(f"Python version: {sys.version}")

# -----------------------------------------------------
# Paths
# -----------------------------------------------------

# Name of the attached Kaggle dataset -- change this to match your upload.
DATASET_NAME = "preprocessed-cog-eeg-dataset"

# Directory that holds the per-split pickle files.
DATA_INPUT_DIR = f"/kaggle/input/{DATASET_NAME}/processed"

# Everything the notebook writes goes below the working directory.
WORKING_DIR = "/kaggle/working"
CHECKPOINT_DIR = os.path.join(WORKING_DIR, "checkpoints")
RESULTS_DIR = os.path.join(WORKING_DIR, "results")
LOGS_DIR = os.path.join(WORKING_DIR, "logs")

# Make sure the output directories exist; existing ones are left as they are.
for d in (CHECKPOINT_DIR, RESULTS_DIR, LOGS_DIR):
    os.makedirs(d, exist_ok=True)

print("\nüìÇ Paths configured:")
print(f"  Input: {DATA_INPUT_DIR}")
print(f"  Checkpoints: {CHECKPOINT_DIR}")
print(f"  Results: {RESULTS_DIR}")

# -----------------------------------------------------
# Data constants (must agree with the preprocessing step)
# -----------------------------------------------------

WINDOW_SIZE_SAMPLES = 256
NUM_BANDS = 5  # delta, theta, alpha, beta, gamma
NUM_TASKS = 4  # N-back, MATB-II, PVT, Flanker

# Number of classes for the fixed-size prediction heads.
NUM_OUTPUT_REGIONS = 7
NUM_OUTPUT_BANDS = 5
NUM_OUTPUT_STATES = 4

# NUM_CHANNELS and NUM_OUTPUT_CHANNELS are not set here: they are read from
# the loaded data at training time.

# -----------------------------------------------------
# Model hyperparameters
# -----------------------------------------------------

D_MODEL = 256
NUM_LAYERS = 6
NUM_HEADS = 8
FF_DIM = 1024
DROPOUT = 0.15

BANDPOWER_HIDDEN_DIM = 128
BANDPOWER_OUTPUT_DIM = 128
TASK_EMBEDDING_DIM = 16

# -----------------------------------------------------
# Training hyperparameters
# -----------------------------------------------------

EPOCHS = 20
INITIAL_LR = 1e-4
WARMUP_EPOCHS = 5
WEIGHT_DECAY = 0.01
GRADIENT_CLIP_NORM = 1.0

SAVE_CHECKPOINT_EVERY = 2
EARLY_STOPPING_PATIENCE = 15

print("\nüîß Configuration:")
print(f"  Model: {NUM_LAYERS} layers, {NUM_HEADS} heads, D_MODEL={D_MODEL}")
print(f"  Training: {EPOCHS} epochs, LR={INITIAL_LR}")
print(f"  Checkpointing: every {SAVE_CHECKPOINT_EVERY} epochs")

## 🔌 3. TPU Initialization

In [None]:
# ============================================================
# TPU INITIALIZATION
# Resolves the TPU, connects to it, initializes it and builds a
# tf.distribute.TPUStrategy. On any failure the cell falls back to the
# default strategy. Either way it leaves `strategy` and `BATCH_SIZE`
# defined for the cells that follow.
# ============================================================

import tensorflow as tf

print("=" * 70)
print("üîç KAGGLE TPU INITIALIZATION")
print("=" * 70)
print(f"\nTensorFlow version: {tf.__version__}")

# ============================================================
# TPU setup (auto-detecting resolver)
# ============================================================

try:
    print("\nüì° Detecting TPU...")

    # Resolver created without arguments, i.e. the TPU address is auto-detected.
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print(f"   ‚úÖ TPU detected: {tpu.master()}")

    print("\nüîó Connecting to TPU cluster...")
    # Order matters: connect to the cluster before initializing the TPU system.
    tf.config.experimental_connect_to_cluster(tpu)
    print("   ‚úÖ Connected to cluster")

    print("\n‚ö° Initializing TPU system...")
    # Initialize the TPU system
    tf.tpu.experimental.initialize_tpu_system(tpu)
    print("   ‚úÖ TPU system initialized")

    print("\nüéØ Creating TPU strategy...")
    # Distribution strategy used later via `strategy.scope()`.
    strategy = tf.distribute.TPUStrategy(tpu)

    print("\n" + "=" * 70)
    print("‚úÖ TPU INITIALIZATION SUCCESSFUL!")
    print("=" * 70)

    # Report how many replicas the strategy synchronizes.
    num_replicas = strategy.num_replicas_in_sync
    print(f"\nüöÄ TPU cores detected: {num_replicas}")

    # Purely informational messages; other replica counts print nothing here.
    if num_replicas == 8:
        print("   üéâ PERFECT! All 8 TPU cores active (TPU v5e-8 or v3-8)")
    elif num_replicas == 4:
        print("   ‚úÖ Good! 4 TPU cores active")
    elif num_replicas == 1:
        print("   ‚ö†Ô∏è  WARNING: Only 1 replica detected!")
        print("   üí° This means TPU is NOT properly configured")
        print("   üîß Fix: Settings ‚Üí Accelerator ‚Üí Select 'TPU VM v5e-8'")

    # List the logical TPU devices TensorFlow can see (first three only).
    print("\nüîç Verifying TPU devices...")
    tpu_devices = tf.config.list_logical_devices('TPU')
    print(f"   Found {len(tpu_devices)} TPU device(s)")
    if len(tpu_devices) > 0:
        for i, device in enumerate(tpu_devices[:3]):  # Show first 3
            print(f"   ‚Ä¢ {device.name}")
        if len(tpu_devices) > 3:
            print(f"   ‚Ä¢ ... and {len(tpu_devices) - 3} more")

    # Global batch size scales with the number of replicas.
    BATCH_SIZE_PER_REPLICA = 64
    BATCH_SIZE = BATCH_SIZE_PER_REPLICA * num_replicas

    print(f"\nüìä Batch Configuration:")
    print(f"   Per-replica batch: {BATCH_SIZE_PER_REPLICA}")
    print(f"   Global batch size: {BATCH_SIZE} ({BATCH_SIZE_PER_REPLICA} √ó {num_replicas})")

    print("\n" + "=" * 70)
    print("üéØ READY FOR TRAINING!")
    print("=" * 70)
    print("\nüí° IMPORTANT: Wrap model creation in strategy scope:")
    print("   with strategy.scope():")
    print("       model = create_model()")
    print("       model.compile(...)")

    print("\n‚úÖ TPU initialization complete!")

except ValueError as ve:
    # Handled as "no TPU available": print a checklist, then fall back.
    print(f"\n‚ùå TPU NOT FOUND: {ve}")
    print("\nüîç TROUBLESHOOTING CHECKLIST:")
    print("   ‚ùå Step 1: Click ‚öôÔ∏è Settings (top right)")
    print("   ‚ùå Step 2: Under 'Accelerator', select 'TPU VM v5e-8'")
    print("   ‚ùå Step 3: Click 'Save' and wait for kernel restart")
    print("   ‚ùå Step 4: Check top of notebook for 'TPU v5e-8: Xh remaining'")
    print("   ‚ùå Step 5: Verify you have TPU quota (need at least 1 hour)")

    # Fallback defines only `strategy` and `BATCH_SIZE`; `num_replicas` and
    # `BATCH_SIZE_PER_REPLICA` are not set on this path.
    print("\nüîÑ Falling back to CPU/GPU strategy...")
    strategy = tf.distribute.get_strategy()
    BATCH_SIZE = 32
    print(f"\n   Using: {strategy.__class__.__name__}")
    print(f"   Batch size: {BATCH_SIZE}")
    print("\n‚ö†Ô∏è  Training will be ~10-20x SLOWER without TPU!")

except Exception as e:
    # Any other failure: report it (message truncated to 200 characters) and
    # use the same fallback as above.
    print(f"\n‚ùå UNEXPECTED ERROR: {type(e).__name__}")
    print(f"   Details: {str(e)[:200]}")

    print("\nüîç Common causes:")
    print("   ‚Ä¢ TPU accelerator not selected in settings")
    print("   ‚Ä¢ TPU quota exhausted (check top of notebook)")
    print("   ‚Ä¢ Kernel needs restart after changing accelerator")
    print("   ‚Ä¢ This cell must run FIRST (before other TF code)")

    print("\nüîÑ Falling back to CPU/GPU strategy...")
    strategy = tf.distribute.get_strategy()
    BATCH_SIZE = 32
    print(f"\n   Using: {strategy.__class__.__name__}")
    print(f"   Batch size: {BATCH_SIZE}")

print("\n" + "=" * 70)
print("‚úÖ CONFIGURATION COMPLETE")
print("=" * 70)

In [None]:
# ============================================================
# AUTO-SAVE CALLBACK (Add this BEFORE training section)
# ============================================================

class AutoSaveCallback(tf.keras.callbacks.Callback):
    """Keras callback that writes a small progress marker file after each epoch.

    The marker is a text file named ``progress_epoch_<N>.txt`` inside
    ``WORKING_DIR`` containing the epoch number, the training and validation
    loss (``0`` when a value is missing) and a timestamp.
    """

    def on_epoch_end(self, epoch, logs=None):
        """Write the marker file for the epoch that just finished.

        Args:
            epoch: Zero-based epoch index supplied by Keras.
            logs: Metric dictionary for the epoch; may be ``None``.
        """
        # `logs` defaults to None; treat that as "no metrics" rather than
        # crashing on `.get`.
        logs = logs or {}

        # NOTE(review): the original comment says Kaggle auto-commits when new
        # files appear -- not verifiable from this file; confirm.
        marker_path = os.path.join(WORKING_DIR, f"progress_epoch_{epoch+1}.txt")
        with open(marker_path, 'w') as f:
            f.write(f"Completed epoch {epoch+1}/{EPOCHS}\n")
            f.write(f"Loss: {logs.get('loss', 0):.4f}\n")
            f.write(f"Val Loss: {logs.get('val_loss', 0):.4f}\n")
            f.write(f"Timestamp: {datetime.now()}\n")

        print(f"  üíæ Progress saved: progress_epoch_{epoch+1}.txt")

print("‚úÖ Auto-save callback ready!")

## 📊 4. Data Loading (Compatible with Preprocessing)

In [None]:
def load_preprocessed_data(split='train'):
    """
    Read ``<split>_data.pkl`` for one split and turn it into NumPy arrays.

    The pickle is expected to hold a sequence of per-sample dicts with the
    keys ``X``, ``bp``, ``task_idx``, ``y_channel``, ``y_region``, ``y_band``
    and ``y_state``. Each ``X`` is transposed and each ``bp`` is flattened
    before the samples are stacked.

    Args:
        split: Name of the split; selects ``DATA_INPUT_DIR/<split>/<split>_data.pkl``.

    Returns:
        ``(X, y, metadata)`` on success, where
        X = (X_eeg, X_bp, X_task),
        y = (y_channel, y_region, y_band, y_state) and
        metadata = dict with ``num_channels``, ``num_samples``, ``bandpower_dim``.
        ``None`` when the file is missing, empty, or cannot be processed.
    """
    pkl_path = os.path.join(DATA_INPUT_DIR, split, f"{split}_data.pkl")
    print(f"\nüìÅ Loading {split} data from: {pkl_path}")

    # Guard clause: nothing to do without the file.
    if not os.path.exists(pkl_path):
        print("  ‚ùå File not found!")
        print("  üí° Check that dataset is attached and DATASET_NAME is correct")
        return None

    try:
        # NOTE: pickle.load executes arbitrary code from the file; only use it
        # on datasets you produced or trust.
        with open(pkl_path, 'rb') as f:
            samples = pickle.load(f)
        print(f"  ‚úÖ Loaded {len(samples):,} samples")

        if len(samples) == 0:
            print("  ‚ùå No samples in file!")
            return None

        # The first sample tells us the per-sample array shapes.
        first = samples[0]
        X_shape = first['X'].shape
        bp_shape = first['bp'].shape
        num_channels = X_shape[0]

        print("\n  üìä Data format:")
        print(f"     X shape: {X_shape} (channels, time)")
        print(f"     bp shape: {bp_shape} (channels, bands)")
        print(f"     Channels: {num_channels}")

        print("\n  üîÑ Converting to arrays...")

        def stack_int(key):
            # Stack one integer field across all samples.
            return np.array([s[key] for s in samples], dtype=np.int32)

        # EEG windows are stored channel-first; the transpose puts time first.
        X_eeg = np.array([s['X'].T for s in samples], dtype=np.float32)
        # Bandpower matrices become one flat feature vector per sample.
        X_bp = np.array([s['bp'].flatten() for s in samples], dtype=np.float32)
        X_task = stack_int('task_idx')

        y_channel, y_region, y_band, y_state = (
            stack_int(key)
            for key in ('y_channel', 'y_region', 'y_band', 'y_state')
        )

        print("  ‚úÖ Final shapes:")
        print(f"     X_eeg: {X_eeg.shape} (N, time, channels)")
        print(f"     X_bp: {X_bp.shape} (N, features)")
        print(f"     X_task: {X_task.shape}")
        print(f"     Labels: {y_channel.shape} each")

        features = (X_eeg, X_bp, X_task)
        labels = (y_channel, y_region, y_band, y_state)
        metadata = {
            'num_channels': num_channels,
            'num_samples': len(samples),
            'bandpower_dim': X_bp.shape[1]
        }
        return features, labels, metadata

    except Exception as e:
        # Best effort: report the problem with a traceback and signal failure
        # to the caller through None.
        print(f"  ‚ùå Error loading data: {e}")
        import traceback
        traceback.print_exc()
        return None

print("‚úÖ Data loading function defined")

## 🏗️ 5. Model Architecture (Flexible Channels)

In [None]:
def create_model(num_channels, bandpower_input_dim, num_output_channels):
    """
    Build the three-input, four-output Keras model named 'CogniVue_Transformer'.

    An EEG stream (Dense projection, additive positional embedding, a stack of
    attention + feed-forward blocks, average pooling over time), a two-layer
    bandpower MLP and a task embedding are concatenated and fed to four
    independent Dense heads.

    Args:
        num_channels: Number of EEG channels (e.g., 58); last dimension of the
            'eeg' input, whose time dimension is WINDOW_SIZE_SAMPLES.
        bandpower_input_dim: Length of the flat 'bp' feature vector.
        num_output_channels: Number of output classes for channel prediction.

    Returns:
        An uncompiled tf.keras.Model whose outputs are a dict with keys
        'channel', 'region', 'band' and 'state'. The heads have no activation,
        so they emit raw logits.
    """
    # Inputs
    eeg_input = tf.keras.Input(shape=(WINDOW_SIZE_SAMPLES, num_channels), name='eeg')
    bp_input = tf.keras.Input(shape=(bandpower_input_dim,), name='bp')
    # Task index arrives as a single int32 per sample.
    task_input = tf.keras.Input(shape=(1,), dtype='int32', name='task')

    # ==================== EEG STREAM ====================
    # Project every time step from num_channels to D_MODEL features.
    x = tf.keras.layers.Dense(D_MODEL, name='eeg_projection')(eeg_input)

    # Positional encoding: one D_MODEL vector per time index, added to x.
    # NOTE(review): the Embedding layer is called on a concrete tf.range
    # tensor rather than on a symbolic input, so its weights may not be
    # tracked/trained as part of the model -- confirm via
    # model.trainable_variables.
    positions = tf.range(start=0, limit=WINDOW_SIZE_SAMPLES, delta=1)
    pos_emb = tf.keras.layers.Embedding(
        input_dim=WINDOW_SIZE_SAMPLES,
        output_dim=D_MODEL,
        name='positional_embedding'
    )(positions)
    x = x + pos_emb

    # Transformer layers: each block is self-attention and a feed-forward
    # network, each followed by a residual Add and then LayerNormalization.
    for i in range(NUM_LAYERS):
        # Self-attention: query and value are both x.
        attn = tf.keras.layers.MultiHeadAttention(
            num_heads=NUM_HEADS,
            key_dim=D_MODEL // NUM_HEADS,
            dropout=DROPOUT,
            name=f'mha_{i}'
        )(x, x)
        x = tf.keras.layers.Add(name=f'add_attn_{i}')([x, attn])
        x = tf.keras.layers.LayerNormalization(epsilon=1e-6, name=f'ln_attn_{i}')(x)

        # Position-wise feed-forward: D_MODEL -> FF_DIM -> D_MODEL, then dropout.
        ffn = tf.keras.Sequential([
            tf.keras.layers.Dense(FF_DIM, activation='relu'),
            tf.keras.layers.Dense(D_MODEL),
            tf.keras.layers.Dropout(DROPOUT)
        ], name=f'ffn_{i}')

        ffn_out = ffn(x)
        x = tf.keras.layers.Add(name=f'add_ffn_{i}')([x, ffn_out])
        x = tf.keras.layers.LayerNormalization(epsilon=1e-6, name=f'ln_ffn_{i}')(x)

    # Average over the time axis to get one vector per sample.
    eeg_emb = tf.keras.layers.GlobalAveragePooling1D(name='eeg_pool')(x)

    # ==================== BANDPOWER STREAM ====================
    bp_x = tf.keras.layers.Dense(BANDPOWER_HIDDEN_DIM, activation='relu', name='bp_hidden')(bp_input)
    bp_emb = tf.keras.layers.Dense(BANDPOWER_OUTPUT_DIM, activation='relu', name='bp_output')(bp_x)

    # ==================== TASK STREAM ====================
    # Embedding lookup over NUM_TASKS indices, flattened to a plain vector.
    task_emb = tf.keras.layers.Embedding(NUM_TASKS, TASK_EMBEDDING_DIM, name='task_emb')(task_input)
    task_emb = tf.keras.layers.Flatten(name='task_flatten')(task_emb)

    # ==================== FUSION ====================
    fused = tf.keras.layers.Concatenate(name='fusion')([eeg_emb, bp_emb, task_emb])

    # ==================== MULTI-TASK HEADS ====================
    # No activation on the heads: each outputs logits over its classes.
    out_channel = tf.keras.layers.Dense(num_output_channels, name='channel')(fused)
    out_region = tf.keras.layers.Dense(NUM_OUTPUT_REGIONS, name='region')(fused)
    out_band = tf.keras.layers.Dense(NUM_OUTPUT_BANDS, name='band')(fused)
    out_state = tf.keras.layers.Dense(NUM_OUTPUT_STATES, name='state')(fused)

    # The output dict keys are the names used for losses and metrics at
    # compile time.
    model = tf.keras.Model(
        inputs=[eeg_input, bp_input, task_input],
        outputs={
            'channel': out_channel,
            'region': out_region,
            'band': out_band,
            'state': out_state
        },
        name='CogniVue_Transformer'
    )

    return model

print("‚úÖ Model architecture defined")

## 📈 6. Learning Rate Schedule

In [None]:
class WarmupCosineDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Linear warmup to a peak learning rate followed by cosine decay to zero.

    For ``step < warmup_steps`` the rate grows linearly from 0 to
    ``initial_learning_rate``. Afterwards it follows half a cosine period and
    reaches 0 at ``total_steps``. Past ``total_steps`` it stays at 0 instead
    of rising again along the cosine.
    """

    def __init__(self, initial_learning_rate, warmup_steps, total_steps):
        """Store the schedule parameters.

        Args:
            initial_learning_rate: Peak learning rate reached after warmup.
            warmup_steps: Number of optimizer steps spent warming up.
            total_steps: Step at which the cosine decay reaches zero.
        """
        super().__init__()
        self.initial_learning_rate = initial_learning_rate
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps

    def __call__(self, step):
        """Return the learning rate for optimizer step ``step``."""
        step = tf.cast(step, tf.float32)
        warmup_steps = tf.cast(self.warmup_steps, tf.float32)
        total_steps = tf.cast(self.total_steps, tf.float32)

        # tf.maximum(..., 1.0) only matters for warmup_steps == 0, where it
        # avoids a division by zero (the warmup branch is never selected then).
        warmup_lr = (step / tf.maximum(warmup_steps, 1.0)) * self.initial_learning_rate

        # Same guard for total_steps == warmup_steps.
        decay_steps = tf.maximum(total_steps - warmup_steps, 1.0)
        decay_step = step - warmup_steps
        # Clamp progress to [0, 1] so the rate does not climb back up once
        # step exceeds total_steps.
        progress = tf.clip_by_value(decay_step / decay_steps, 0.0, 1.0)
        cosine_decay = 0.5 * (1 + tf.cos(tf.constant(np.pi) * progress))
        decay_lr = self.initial_learning_rate * cosine_decay

        return tf.cond(
            step < warmup_steps,
            lambda: warmup_lr,
            lambda: decay_lr
        )

    def get_config(self):
        """Return the constructor arguments so the schedule can be serialized."""
        return {
            "initial_learning_rate": self.initial_learning_rate,
            "warmup_steps": self.warmup_steps,
            "total_steps": self.total_steps,
        }

print("‚úÖ LR schedule defined")

## 🔄 7. Data Pipeline & Callbacks

In [None]:
def create_tf_dataset(X, y, is_train=True):
    """Wrap feature and label arrays in a batched, prefetched tf.data.Dataset.

    Args:
        X: Indexable of three arrays used as the 'eeg', 'bp' and 'task' inputs.
        y: Indexable of four arrays used as the 'channel', 'region', 'band'
            and 'state' targets.
        is_train: When true, samples are shuffled (buffer of 10000) before
            batching.

    Returns:
        A dataset of ``(features_dict, labels_dict)`` batches of size
        BATCH_SIZE. Incomplete final batches are dropped.
    """
    eeg, bp, task = X[0], X[1], X[2]
    channel, region, band, state = y[0], y[1], y[2], y[3]

    features = {'eeg': eeg, 'bp': bp, 'task': task}
    targets = {'channel': channel, 'region': region, 'band': band, 'state': state}
    ds = tf.data.Dataset.from_tensor_slices((features, targets))

    if is_train:
        ds = ds.shuffle(10000)

    # drop_remainder keeps every batch the same size.
    return ds.batch(BATCH_SIZE, drop_remainder=True).prefetch(tf.data.AUTOTUNE)


class PeriodicCheckpoint(tf.keras.callbacks.Callback):
    """Save the full model to CHECKPOINT_DIR every ``save_freq`` epochs."""

    def __init__(self, save_freq=5):
        """Remember how many epochs lie between two checkpoints."""
        super().__init__()
        self.save_freq = save_freq

    def on_epoch_end(self, epoch, logs=None):
        """Write ``checkpoint_epoch_<NNN>.keras`` when the epoch count is due."""
        finished = epoch + 1  # Keras passes a zero-based epoch index
        if finished % self.save_freq != 0:
            return

        filename = f"checkpoint_epoch_{finished:03d}.keras"
        self.model.save(os.path.join(CHECKPOINT_DIR, filename))
        print(f"\n  üíæ Saved checkpoint: {filename}")


class LearningRateLogger(tf.keras.callbacks.Callback):
    """Print the optimizer's current learning rate after every epoch.

    The value is also stored in ``logs['learning_rate']`` when a logs dict is
    supplied.
    """

    def on_epoch_end(self, epoch, logs=None):
        """Resolve the learning rate to a float, record it and print it."""
        optimizer = self.model.optimizer
        lr = optimizer.learning_rate
        # A callable learning rate (e.g. a schedule) has to be evaluated at
        # the current iteration count; anything else is used as is.
        lr_value = lr(optimizer.iterations) if hasattr(lr, '__call__') else lr
        lr_float = float(tf.keras.backend.get_value(lr_value))

        if logs is not None:
            logs['learning_rate'] = lr_float
        print(f"\n  üìä LR = {lr_float:.6f}")


class ProgressLogger(tf.keras.callbacks.Callback):
    """Print a banner at the start of each epoch and a timed summary at its end."""

    def __init__(self):
        """Start without a recorded epoch start time."""
        super().__init__()
        self.epoch_start = None

    def on_epoch_begin(self, epoch, logs=None):
        """Record the wall-clock start time and print the epoch banner."""
        self.epoch_start = time.time()
        rule = '=' * 70
        print("\n" + rule)
        print(f"üìÖ Epoch {epoch+1}/{EPOCHS}")
        print(rule)

    def on_epoch_end(self, epoch, logs=None):
        """Print the elapsed time and, when available, loss and region accuracy."""
        elapsed = time.time() - self.epoch_start
        print(f"\n‚úÖ Epoch {epoch+1} done in {elapsed:.1f}s")
        if logs:
            # Missing metrics are shown as 0.
            loss = logs.get('loss', 0)
            val_loss = logs.get('val_loss', 0)
            print(f"   Loss: {loss:.4f} | Val Loss: {val_loss:.4f}")
            region_acc = logs.get('region_accuracy', 0)
            val_region_acc = logs.get('val_region_accuracy', 0)
            print(f"   Region Acc: {region_acc:.4f} | Val: {val_region_acc:.4f}")
        print('=' * 70 + "\n")

print("‚úÖ Pipeline & callbacks defined")

## 🚀 8. Main Training Function

In [None]:
def _take_rows(arrays, idx):
    """Return a tuple holding every array in *arrays* indexed by *idx*."""
    return tuple(a[idx] for a in arrays)


def train_cognivue():
    """Load the training split, train the model and save checkpoints.

    Steps: load the 'train' pickle, split it 80/20 into train/validation with
    a fixed seed, build and compile the model inside ``strategy.scope()``,
    then fit with progress, learning-rate, checkpoint, early-stopping and
    TensorBoard callbacks.

    Returns:
        The Keras ``History`` object on success. ``None`` when data loading
        fails, when a split is smaller than one global batch, or when training
        is interrupted or raises.
    """
    print("\n" + "="*70)
    print("üß† CogniVue Training Pipeline")
    print("="*70)

    # Only the training split is loaded; validation data is carved out of it.
    print("\nüìä Loading data...")
    train_result = load_preprocessed_data('train')

    if train_result is None:
        print("\n‚ùå Data loading failed!")
        return None

    train_data, train_labels, train_meta = train_result

    # =====================================================
    # SPLIT TRAIN DATA INTO TRAIN/VAL (80/20)
    # =====================================================
    print("\n‚úÇÔ∏è Splitting data into train/val...")

    num_samples = len(train_data[0])
    indices = np.arange(num_samples)
    np.random.seed(42)  # Fixed seed so the split is reproducible
    np.random.shuffle(indices)

    split_idx = int(0.8 * num_samples)
    train_idx = indices[:split_idx]
    val_idx = indices[split_idx:]

    train_X = _take_rows(train_data, train_idx)    # (X_eeg, X_bp, X_task)
    val_X = _take_rows(train_data, val_idx)
    train_y = _take_rows(train_labels, train_idx)  # (y_channel, y_region, y_band, y_state)
    val_y = _take_rows(train_labels, val_idx)

    print(f"  ‚úÖ Train samples: {len(train_idx):,}")
    print(f"  ‚úÖ Val samples: {len(val_idx):,}")

    # Model dimensions come from the data, not from the configuration cell.
    NUM_CHANNELS = train_meta['num_channels']
    BANDPOWER_INPUT_DIM = train_meta['bandpower_dim']

    # Largest channel label + 1, taken over the whole (pre-split) label set.
    # Cast to a plain int so a NumPy scalar never ends up in the layer config.
    NUM_OUTPUT_CHANNELS = int(train_labels[0].max()) + 1

    print(f"\nüîç Model dimensions:")
    print(f"   Input channels: {NUM_CHANNELS}")
    print(f"   Bandpower dim: {BANDPOWER_INPUT_DIM}")
    print(f"   Output channels: {NUM_OUTPUT_CHANNELS}")

    # Batches are built with drop_remainder=True, so a split smaller than one
    # global batch yields no batches at all (zero training steps or no
    # validation metrics). Stop early with a clear message.
    steps_per_epoch = len(train_idx) // BATCH_SIZE
    if steps_per_epoch == 0 or len(val_idx) < BATCH_SIZE:
        print(f"\n‚ùå Not enough samples for batch size {BATCH_SIZE} "
              f"(train: {len(train_idx):,}, val: {len(val_idx):,})")
        return None

    # Create datasets
    print(f"\nüîÑ Creating TF datasets...")
    train_ds = create_tf_dataset(train_X, train_y, is_train=True)
    val_ds = create_tf_dataset(val_X, val_y, is_train=False)

    total_steps = steps_per_epoch * EPOCHS
    warmup_steps = steps_per_epoch * WARMUP_EPOCHS

    print(f"   Steps/epoch: {steps_per_epoch:,}")
    print(f"   Total steps: {total_steps:,}")

    # Build model
    print(f"\nÔøΩÔ∏è Building model...")

    # Model, schedule and optimizer are created and compiled under the
    # distribution strategy.
    with strategy.scope():
        model = create_model(NUM_CHANNELS, BANDPOWER_INPUT_DIM, NUM_OUTPUT_CHANNELS)

        print(f"   Parameters: {model.count_params():,}")

        lr_schedule = WarmupCosineDecay(INITIAL_LR, warmup_steps, total_steps)

        optimizer = tf.keras.optimizers.AdamW(
            learning_rate=lr_schedule,
            weight_decay=WEIGHT_DECAY,
            clipnorm=GRADIENT_CLIP_NORM
        )

        # One entry per output head; the heads emit logits, hence
        # from_logits=True.
        loss_weights = {'channel': 0.4, 'region': 0.4, 'band': 0.1, 'state': 0.1}
        model.compile(
            optimizer=optimizer,
            loss={
                head: tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
                for head in loss_weights
            },
            loss_weights=loss_weights,
            metrics={head: ['accuracy'] for head in loss_weights},
        )
    print(f"   ‚úÖ Compiled")

    # Callbacks
    callbacks = [
        ProgressLogger(),
        LearningRateLogger(),
        tf.keras.callbacks.ModelCheckpoint(
            filepath=os.path.join(CHECKPOINT_DIR, 'best_model.keras'),
            monitor='val_loss',
            save_best_only=True,
            verbose=1
        ),
        PeriodicCheckpoint(save_freq=SAVE_CHECKPOINT_EVERY),
        tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=EARLY_STOPPING_PATIENCE,
            restore_best_weights=True,
            verbose=1
        ),
        tf.keras.callbacks.TensorBoard(log_dir=LOGS_DIR)
    ]

    # Train
    print(f"\nüöÄ Starting training...")
    print(f"‚è∞ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")

    try:
        history = model.fit(
            train_ds,
            epochs=EPOCHS,
            validation_data=val_ds,
            callbacks=callbacks,
            verbose=1
        )

        print("\n" + "="*70)
        print("‚úÖ TRAINING COMPLETE!")
        print("="*70)
        print(f"\nüíæ Saved to: {CHECKPOINT_DIR}")

        # Save final
        model.save(os.path.join(CHECKPOINT_DIR, 'final_model.keras'))
        print(f"   - final_model.keras")
        print(f"   - best_model.keras")

        return history

    except KeyboardInterrupt:
        # Keep whatever was learned so far before giving up.
        print("\n‚ö†Ô∏è Training interrupted!")
        model.save(os.path.join(CHECKPOINT_DIR, 'interrupted.keras'))
        print(f"   üíæ Saved: interrupted.keras")
        return None

    except Exception as e:
        # Notebook-level boundary: report with a traceback and signal failure.
        print(f"\n‚ùå Training failed: {e}")
        import traceback
        traceback.print_exc()
        return None

print("‚úÖ Training function ready")

## ▶️ 9. Run Training

In [None]:
# Run the whole pipeline; `history` is None when training did not finish.
history = train_cognivue()

if not history:
    print("\n‚ö†Ô∏è Training did not complete successfully")
else:
    # Short summary: lowest validation loss and last-epoch region accuracy.
    print("\nüìä Training Summary:")
    print(f"   Best val_loss: {min(history.history['val_loss']):.4f}")
    print(f"   Final region acc: {history.history['region_accuracy'][-1]:.4f}")
    print("\nüí° Next: Click 'Save Version' to commit checkpoints!")

## 10. Save Training Results & Metrics

In [None]:
# Persist the training history, the configuration and a Markdown summary to
# RESULTS_DIR. Nothing is written when training did not produce a history.
if history:
    print("\nüìù Saving training results...")

    # Convert history to JSON-serializable format (plain Python floats).
    history_dict = {
        key: [float(val) for val in values]
        for key, values in history.history.items()
    }

    # Save training history
    history_path = os.path.join(RESULTS_DIR, 'training_history.json')
    with open(history_path, 'w') as f:
        json.dump(history_dict, f, indent=2)
    print(f"  ‚úÖ Saved: training_history.json")

    # The model dimensions are local to train_cognivue(), so they are
    # recovered here by reloading the training split.
    train_result = load_preprocessed_data('train')
    if train_result:
        _, train_labels, train_meta = train_result
        NUM_CHANNELS_USED = train_meta['num_channels']
        # Same rule as in training: largest channel label + 1. Cast to int
        # because json cannot serialize NumPy integer scalars.
        NUM_OUTPUT_CHANNELS_USED = int(train_labels[0].max()) + 1
    else:
        NUM_CHANNELS_USED = "unknown"
        NUM_OUTPUT_CHANNELS_USED = "unknown"

    # Save training configuration
    config = {
        'model_architecture': {
            'name': 'CogniVue_Transformer',
            'num_input_channels': NUM_CHANNELS_USED,
            'd_model': D_MODEL,
            'num_layers': NUM_LAYERS,
            'num_heads': NUM_HEADS,
            'ff_dim': FF_DIM,
            'dropout': DROPOUT,
            'window_size': WINDOW_SIZE_SAMPLES
        },
        'training_params': {
            'epochs_trained': len(history.history['loss']),
            'total_epochs': EPOCHS,
            'batch_size': BATCH_SIZE,
            'initial_lr': INITIAL_LR,
            'warmup_epochs': WARMUP_EPOCHS,
            'weight_decay': WEIGHT_DECAY,
            'gradient_clip_norm': GRADIENT_CLIP_NORM
        },
        'output_tasks': {
            # Previously this recorded NUM_OUTPUT_REGIONS by mistake.
            'num_output_channels': NUM_OUTPUT_CHANNELS_USED,
            'num_regions': NUM_OUTPUT_REGIONS,
            'num_bands': NUM_OUTPUT_BANDS,
            'num_states': NUM_OUTPUT_STATES
        },
        'final_metrics': {
            'best_val_loss': float(min(history.history['val_loss'])),
            'final_train_loss': float(history.history['loss'][-1]),
            'final_val_loss': float(history.history['val_loss'][-1]),
            'final_region_accuracy': float(history.history['region_accuracy'][-1]),
            'final_val_region_accuracy': float(history.history['val_region_accuracy'][-1])
        },
        'training_info': {
            'completed_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
            'tensorflow_version': tf.__version__,
            # Derived from the strategy class name chosen in the TPU cell.
            'accelerator': 'TPU' if 'TPU' in str(strategy.__class__) else 'CPU/GPU'
        }
    }

    config_path = os.path.join(RESULTS_DIR, 'training_config.json')
    with open(config_path, 'w') as f:
        json.dump(config, f, indent=2)
    print(f"  ‚úÖ Saved: training_config.json")

    # Create a summary markdown file
    summary_md = f"""# CogniVue Training Summary

**Training Completed:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

## Model Architecture
- **Model:** CogniVue Transformer
- **Input Channels:** {NUM_CHANNELS_USED}
- **Model Dimension:** {D_MODEL}
- **Transformer Layers:** {NUM_LAYERS}
- **Attention Heads:** {NUM_HEADS}
- **Feedforward Dim:** {FF_DIM}
- **Dropout:** {DROPOUT}

## Training Configuration
- **Epochs:** {len(history.history['loss'])}/{EPOCHS}
- **Batch Size:** {BATCH_SIZE}
- **Initial LR:** {INITIAL_LR}
- **Warmup Epochs:** {WARMUP_EPOCHS}
- **Weight Decay:** {WEIGHT_DECAY}

## Final Performance
- **Best Val Loss:** {min(history.history['val_loss']):.4f}
- **Final Train Loss:** {history.history['loss'][-1]:.4f}
- **Final Val Loss:** {history.history['val_loss'][-1]:.4f}
- **Final Region Accuracy:** {history.history['region_accuracy'][-1]:.4f}
- **Final Val Region Accuracy:** {history.history['val_region_accuracy'][-1]:.4f}

## Output Files
- `checkpoints/best_model.keras` - Best model weights
- `checkpoints/final_model.keras` - Final model weights
- `checkpoints/checkpoint_epoch_*.keras` - Periodic checkpoints
- `results/training_history.json` - Loss and metrics per epoch
- `results/training_config.json` - Full configuration
- `logs/` - TensorBoard logs
"""

    summary_path = os.path.join(RESULTS_DIR, 'TRAINING_SUMMARY.md')
    with open(summary_path, 'w') as f:
        f.write(summary_md)
    print(f"  ‚úÖ Saved: TRAINING_SUMMARY.md")

    print("\n‚úÖ All results saved!")
    print(f"\nüìÇ Saved files:")
    print(f"   {RESULTS_DIR}/")
    print(f"   ‚îú‚îÄ‚îÄ training_history.json")
    print(f"   ‚îú‚îÄ‚îÄ training_config.json")
    print(f"   ‚îî‚îÄ‚îÄ TRAINING_SUMMARY.md")
    print(f"\n   {CHECKPOINT_DIR}/")
    print(f"   ‚îú‚îÄ‚îÄ best_model.keras")
    print(f"   ‚îú‚îÄ‚îÄ final_model.keras")
    print(f"   ‚îî‚îÄ‚îÄ checkpoint_epoch_*.keras")

else:
    print("\n‚ö†Ô∏è No training history to save")




## 11. Package & Download All Outputs

In [None]:

# ============================================================
# Section 11: Package & Download All Outputs
# ============================================================

import zipfile
from pathlib import Path

print("\nüì¶ Creating download package...")
print("=" * 70)

# Create zip filename with timestamp
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
zip_filename = f"cognivue_training_outputs_{timestamp}.zip"
zip_path = os.path.join(WORKING_DIR, zip_filename)

try:
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        
        # Add all files from results directory
        print("\nüìä Adding results...")
        if os.path.exists(RESULTS_DIR):
            for file in os.listdir(RESULTS_DIR):
                file_path = os.path.join(RESULTS_DIR, file)
                if os.path.isfile(file_path):
                    arcname = os.path.join('results', file)
                    zipf.write(file_path, arcname)
                    print(f"  ‚úÖ {file}")
        
        # Add all checkpoint files
        print("\nüîñ Adding checkpoints...")
        if os.path.exists(CHECKPOINT_DIR):
            for file in os.listdir(CHECKPOINT_DIR):
                file_path = os.path.join(CHECKPOINT_DIR, file)
                if os.path.isfile(file_path):
                    arcname = os.path.join('checkpoints', file)
                    zipf.write(file_path, arcname)
                    file_size_mb = os.path.getsize(file_path) / (1024 * 1024)
                    print(f"  ‚úÖ {file} ({file_size_mb:.1f} MB)")
        
        # Add TensorBoard logs (optional - can be large)
        print("\nüìà Adding TensorBoard logs...")
        if os.path.exists(LOGS_DIR):
            log_count = 0
            for root, dirs, files in os.walk(LOGS_DIR):
                for file in files:
                    file_path = os.path.join(root, file)
                    arcname = os.path.join('logs', os.path.relpath(file_path, LOGS_DIR))
                    zipf.write(file_path, arcname)
                    log_count += 1
            print(f"  ‚úÖ Added {log_count} log files")
    
    # Get final zip size
    zip_size_mb = os.path.getsize(zip_path) / (1024 * 1024)
    
    print("\n" + "=" * 70)
    print("‚úÖ PACKAGE CREATED SUCCESSFULLY!")
    print("=" * 70)
    print(f"\nüì¶ Zip file: {zip_filename}")
    print(f"üìè Size: {zip_size_mb:.1f} MB")
    print(f"üìç Location: {zip_path}")
    
    print("\nüì• To download:")
    print("   1. Go to the 'Output' tab (top right)")
    print("   2. Click 'Save Version' to commit outputs")
    print(f"   3. Download '{zip_filename}'")
    print("\nüí° Or click the download icon next to the file in the Output tab")
    
    # List contents
    print("\nüìã Package contents:")
    with zipfile.ZipFile(zip_path, 'r') as zipf:
        file_list = zipf.namelist()
        print(f"   Total files: {len(file_list)}")
        print("\n   Structure:")
        print("   ‚îú‚îÄ‚îÄ results/")
        print("   ‚îÇ   ‚îú‚îÄ‚îÄ training_history.json")
        print("   ‚îÇ   ‚îú‚îÄ‚îÄ training_config.json")
        print("   ‚îÇ   ‚îî‚îÄ‚îÄ TRAINING_SUMMARY.md")
        print("   ‚îú‚îÄ‚îÄ checkpoints/")
        print("   ‚îÇ   ‚îú‚îÄ‚îÄ best_model.keras")
        print("   ‚îÇ   ‚îú‚îÄ‚îÄ final_model.keras")
        print("   ‚îÇ   ‚îî‚îÄ‚îÄ checkpoint_epoch_*.keras")
        print("   ‚îî‚îÄ‚îÄ logs/")
        print("       ‚îî‚îÄ‚îÄ TensorBoard logs")
    
    print("\n" + "=" * 70)
    
except Exception as e:
    print(f"\n‚ùå Error creating zip: {e}")
    import traceback
    traceback.print_exc()