#  CogniVue: Training

**Fully compatible with your preprocessing notebook!**

This notebook:
- Loads data from a single `.pkl` file per split
- Handles a variable channel count (typically 58 channels)
- Transposes X from `(n_ch, 256)` to `(256, n_ch)`
- Provides robust checkpointing and error handling
- Resumes training after interruptions



##  1. Configuration

In [None]:


import os
import sys
import json
import time
import shutil
import numpy as np
import tensorflow as tf
import pickle
from datetime import datetime
from tensorflow.keras import regularizers

# Report the runtime versions up front so logs are self-describing.
print(f"TensorFlow version: {tf.__version__}")
print(f"Python version: {sys.version}")

# -----------------------------------------------------
# Filesystem layout
# -----------------------------------------------------

# Name of the attached dataset; change this if your dataset is named differently.
DATASET_NAME = "preprocessed-cog-eeg-dataset"

# Where the preprocessed splits are read from.
DATA_INPUT_DIR = f"/kaggle/input/{DATASET_NAME}/processed"

# Where everything produced by this notebook is written.
WORKING_DIR = "/kaggle/working"
CHECKPOINT_DIR, RESULTS_DIR, LOGS_DIR = (
    os.path.join(WORKING_DIR, subdir)
    for subdir in ("checkpoints", "results", "logs")
)

# Create the output directories eagerly; already-existing ones are fine.
for output_dir in (CHECKPOINT_DIR, RESULTS_DIR, LOGS_DIR):
    os.makedirs(output_dir, exist_ok=True)

print(f"\n‚úÖ Paths configured:")
print(f"  Input: {DATA_INPUT_DIR}")
print(f"  Checkpoints: {CHECKPOINT_DIR}")
print(f"  Results: {RESULTS_DIR}")

# -----------------------------------------------------
# Data constants (must agree with the preprocessing step)
# -----------------------------------------------------

WINDOW_SIZE_SAMPLES = 256  # time samples per EEG window
NUM_BANDS = 5              # delta, theta, alpha, beta, gamma
NUM_TASKS = 4              # N-back, MATB-II, PVT, Flanker

# Number of classes for the fixed-size prediction heads.
NUM_OUTPUT_REGIONS = 7
NUM_OUTPUT_BANDS = 5
NUM_OUTPUT_STATES = 4

# NUM_CHANNELS and NUM_OUTPUT_CHANNELS are not fixed here: they are read
# from the loaded data.

# -----------------------------------------------------
# Model hyperparameters
# -----------------------------------------------------

# Transformer capacity (previously 256 / 6 / 1024 for D_MODEL / layers / FF_DIM).
D_MODEL = 128    # was 256
NUM_LAYERS = 4   # was 6
NUM_HEADS = 8    # unchanged
FF_DIM = 512     # was 1024
DROPOUT = 0.3    # was 0.15

BANDPOWER_HIDDEN_DIM = 128
BANDPOWER_OUTPUT_DIM = 128
TASK_EMBEDDING_DIM = 16

# -----------------------------------------------------
# Training hyperparameters
# -----------------------------------------------------

EPOCHS = 100
BATCH_SIZE = 32      # was 64
INITIAL_LR = 5e-5    # was 1e-4
WARMUP_EPOCHS = 10
WEIGHT_DECAY = 0.01
GRADIENT_CLIP_NORM = 1.0

SAVE_CHECKPOINT_EVERY = 2      # epochs between periodic checkpoints
EARLY_STOPPING_PATIENCE = 15   # epochs without val_loss improvement

# L2 penalty factor applied through kernel/embedding regularizers.
L2_REGULARIZATION = 0.01

print(f"\nüîß ‚úÖ CORRECTED Configuration:")
print(f"  Model: {NUM_LAYERS} layers, {NUM_HEADS} heads, D_MODEL={D_MODEL}")
print(f"  Training: {EPOCHS} epochs, LR={INITIAL_LR}, Batch={BATCH_SIZE}")
print(f"  Regularization: Dropout={DROPOUT}, L2={L2_REGULARIZATION}")
print(f"  Checkpointing: every {SAVE_CHECKPOINT_EVERY} epochs")
print(f"  Early stopping patience: {EARLY_STOPPING_PATIENCE} epochs")


##  2. Initialization

In [None]:

# Pick a tf.distribute strategy: a TPU if one can be resolved, otherwise
# MirroredStrategy for several GPUs, otherwise TensorFlow's default strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print(f"\nüî• TPU detected: {tpu.cluster_spec().as_dict()['worker']}")
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)
    print(f"‚úÖ Running on TPU with {strategy.num_replicas_in_sync} cores")
except ValueError:
    # TPU resolution failed with ValueError: fall back to GPU, then CPU.
    gpus = tf.config.list_physical_devices('GPU')
    if not gpus:
        print("\nüíª No GPU/TPU detected, running on CPU")
        strategy = tf.distribute.get_strategy()
    else:
        print(f"\nüéÆ GPUs detected: {len(gpus)} GPU(s)")
        for idx, device in enumerate(gpus):
            print(f"  GPU {idx}: {device}")

        # Ask TensorFlow to grow GPU memory on demand; a RuntimeError here
        # is reported and otherwise ignored.
        try:
            for device in gpus:
                tf.config.experimental.set_memory_growth(device, True)
            print("‚úÖ GPU memory growth enabled")
        except RuntimeError as e:
            print(f"‚ö†Ô∏è Could not set memory growth: {e}")

        if len(gpus) <= 1:
            # One GPU: the default strategy is enough.
            strategy = tf.distribute.get_strategy()
            print("‚úÖ Running on single GPU")
        else:
            # Several GPUs: replicate the model across them.
            strategy = tf.distribute.MirroredStrategy()
            print(f"‚úÖ Running on {len(gpus)} GPUs with MirroredStrategy")
            print(f"   Devices: {strategy.extended.worker_devices}")

print(f"\nüìä Strategy info:")
print(f"  Number of replicas: {strategy.num_replicas_in_sync}")
print(f"  Effective batch size: {BATCH_SIZE * strategy.num_replicas_in_sync}")

##  3. Data Loading Function

In [None]:



def load_preprocessed_data(split='train'):
    """
    Load one preprocessed split from ``<DATA_INPUT_DIR>/<split>/<split>_data.pkl``.

    The pickle is expected to hold a sequence of per-window dicts with the keys
    ``X``, ``bp``, ``task_idx``, ``y_channel``, ``y_region``, ``y_band`` and
    ``y_state``. Each ``X`` is transposed before stacking and each ``bp`` is
    flattened, so the model receives time-major EEG and a flat bandpower vector.

    Args:
        split: Name of the split sub-directory / file prefix (e.g. 'train', 'val').

    Returns:
        ``None`` if the file is missing, empty, or fails to load; otherwise a
        tuple ``(X, y, metadata)`` where
        X = (X_eeg, X_bp, X_task)
        y = (y_channel, y_region, y_band, y_state)
        metadata = dict with ``num_channels``, ``num_samples``,
        ``bandpower_dim`` and ``num_output_channels``.
    """
    pkl_path = os.path.join(DATA_INPUT_DIR, split, f"{split}_data.pkl")

    print(f"\nüìÇ Loading {split} data from: {pkl_path}")

    if not os.path.exists(pkl_path):
        print(f"   ‚ùå File not found!")
        print(f"   Check that dataset is attached and DATASET_NAME is correct")
        return None

    try:
        # SECURITY NOTE: pickle.load can execute arbitrary code embedded in the
        # file. Only point DATA_INPUT_DIR at datasets you produced or trust.
        with open(pkl_path, 'rb') as f:
            samples = pickle.load(f)

        print(f"   ‚úÖ Loaded {len(samples):,} samples")

        if len(samples) == 0:
            print(f"   ‚ùå No samples in file!")
            return None

        # The first sample defines the dimensions reported for the split.
        sample = samples[0]
        X_shape = sample['X'].shape    # printed below as (channels, time)
        bp_shape = sample['bp'].shape  # printed below as (channels, bands)

        num_channels = X_shape[0]

        print(f"\n  üìä Data format:")
        print(f"     X shape: {X_shape} (channels, time)")
        print(f"     bp shape: {bp_shape} (channels, bands)")
        print(f"     Channels: {num_channels}")

        print(f"\n   üîÑ Converting to arrays...")

        # EEG windows: transpose each sample, then stack along a new batch axis.
        X_eeg = np.array([s['X'].T for s in samples], dtype=np.float32)

        # Bandpower: flatten each sample to a single feature vector.
        X_bp = np.array([s['bp'].flatten() for s in samples], dtype=np.float32)

        # Task index per sample.
        X_task = np.array([s['task_idx'] for s in samples], dtype=np.int32)

        # Integer class labels for the four prediction heads.
        y_channel = np.array([s['y_channel'] for s in samples], dtype=np.int32)
        y_region = np.array([s['y_region'] for s in samples], dtype=np.int32)
        y_band = np.array([s['y_band'] for s in samples], dtype=np.int32)
        y_state = np.array([s['y_state'] for s in samples], dtype=np.int32)

        print(f"   ‚úÖ Final shapes:")
        print(f"     X_eeg: {X_eeg.shape} (N, time, channels)")
        print(f"     X_bp: {X_bp.shape} (N, features)")
        print(f"     X_task: {X_task.shape}")
        print(f"     Labels: {y_channel.shape} each")

        # The channel head needs one logit per label id, so size it by the
        # largest id rather than by the number of distinct ids: counting unique
        # values undercounts when ids are non-contiguous or do not start at 0,
        # which would put the largest label outside the softmax range.
        num_output_channels = int(y_channel.max()) + 1
        if num_output_channels != len(np.unique(y_channel)):
            print(f"   Note: channel labels are not contiguous; "
                  f"using max label + 1 = {num_output_channels} output classes")

        metadata = {
            'num_channels': num_channels,
            'num_samples': len(samples),
            'bandpower_dim': X_bp.shape[1],
            'num_output_channels': num_output_channels
        }

        return (X_eeg, X_bp, X_task), (y_channel, y_region, y_band, y_state), metadata

    except Exception as e:
        # Best-effort loader: report the failure and let the caller decide.
        print(f"  ‚ùå Error loading data: {e}")
        import traceback
        traceback.print_exc()
        return None

print("‚úÖ Data loading function defined")



##  4. Model Architecture

In [None]:



class LearnedPositionalEmbedding(tf.keras.layers.Layer):
    """Add a trainable per-position embedding to a (batch, time, features) input.

    The embedding table is created with ``add_weight`` so that it is tracked,
    trained, regularized and saved as part of whatever model uses the layer.

    NOTE: this is a custom layer. Loading a saved model that contains it
    requires ``custom_objects={'LearnedPositionalEmbedding':
    LearnedPositionalEmbedding}``.

    Args:
        sequence_length: Number of positions (time steps) in the input.
        embedding_dim: Size of the last input axis.
        embeddings_regularizer: Optional regularizer (object, identifier or
            serialized config) applied to the embedding table.
    """

    def __init__(self, sequence_length, embedding_dim,
                 embeddings_regularizer=None, **kwargs):
        super().__init__(**kwargs)
        self.sequence_length = sequence_length
        self.embedding_dim = embedding_dim
        self.embeddings_regularizer = regularizers.get(embeddings_regularizer)

    def build(self, input_shape):
        # 'uniform' is the same default initializer the Embedding layer uses.
        self.position_embeddings = self.add_weight(
            name='embeddings',
            shape=(self.sequence_length, self.embedding_dim),
            initializer='uniform',
            regularizer=self.embeddings_regularizer,
            trainable=True
        )

    def call(self, inputs):
        # Broadcasts the (time, features) table over the batch axis.
        return inputs + self.position_embeddings

    def get_config(self):
        config = super().get_config()
        config.update({
            'sequence_length': self.sequence_length,
            'embedding_dim': self.embedding_dim,
            'embeddings_regularizer': regularizers.serialize(self.embeddings_regularizer),
        })
        return config


def create_model(num_channels, bandpower_input_dim, num_output_channels):
    """
    Build the multi-input, multi-head EEG transformer.

    Three input streams are encoded separately and concatenated:
    - 'eeg':  a (WINDOW_SIZE_SAMPLES, num_channels) window, projected to
      D_MODEL, given learned positional embeddings, passed through NUM_LAYERS
      post-norm transformer blocks and average-pooled over time.
    - 'bp':   a flat bandpower vector passed through two Dense layers.
    - 'task': an integer task index passed through an embedding.

    Four Dense heads ('channel', 'region', 'band', 'state') read the fused
    vector, each behind its own dropout rate; the channel head has two extra
    hidden layers. All heads output raw logits (no softmax).

    Regularization: every Dense kernel and embedding carries an L2 penalty of
    L2_REGULARIZATION, and dropout is applied after the projection, attention,
    feed-forward layers, pooling and bandpower layers.

    Args:
        num_channels: Number of EEG channels (e.g., 58)
        bandpower_input_dim: Bandpower feature dimension (num_channels * 5)
        num_output_channels: Number of output classes for channel prediction

    Returns:
        An uncompiled ``tf.keras.Model`` whose outputs are a dict keyed by
        'channel', 'region', 'band' and 'state'.
    """

    # One regularizer instance is shared by all regularized layers.
    l2_reg = regularizers.l2(L2_REGULARIZATION)

    print(f"\nüèóÔ∏è  Building model:")
    print(f"   Input channels: {num_channels}")
    print(f"   Bandpower dim: {bandpower_input_dim}")
    print(f"   Output channels: {num_output_channels}")
    print(f"   D_MODEL: {D_MODEL}, Layers: {NUM_LAYERS}, Dropout: {DROPOUT}")

    # ==================== INPUTS ====================
    eeg_input = tf.keras.Input(shape=(WINDOW_SIZE_SAMPLES, num_channels), name='eeg')
    bp_input = tf.keras.Input(shape=(bandpower_input_dim,), name='bp')
    task_input = tf.keras.Input(shape=(1,), dtype='int32', name='task')

    # ==================== EEG STREAM ====================
    # Per-time-step linear projection to the transformer width.
    x = tf.keras.layers.Dense(D_MODEL, kernel_regularizer=l2_reg,
                              name='eeg_projection')(eeg_input)
    x = tf.keras.layers.Dropout(DROPOUT)(x)

    # Learned positional encoding. This must be a layer applied to the symbolic
    # tensor: calling Embedding on a concrete tf.range(...) tensor executes
    # eagerly during model construction, so the result is baked into the graph
    # as a constant and the embedding weights are never trained, regularized
    # or saved with the model.
    x = LearnedPositionalEmbedding(
        sequence_length=WINDOW_SIZE_SAMPLES,
        embedding_dim=D_MODEL,
        embeddings_regularizer=l2_reg,
        name='positional_embedding'
    )(x)

    # Transformer encoder blocks (post-norm: residual add, then LayerNorm).
    for i in range(NUM_LAYERS):
        # Self-attention over the time axis.
        attn = tf.keras.layers.MultiHeadAttention(
            num_heads=NUM_HEADS,
            key_dim=D_MODEL // NUM_HEADS,
            dropout=DROPOUT,
            name=f'mha_{i}'
        )(x, x)

        attn = tf.keras.layers.Dropout(DROPOUT)(attn)
        x = tf.keras.layers.Add(name=f'add_attn_{i}')([x, attn])
        x = tf.keras.layers.LayerNormalization(epsilon=1e-6,
                                               name=f'ln_attn_{i}')(x)

        # Position-wise feed-forward network.
        ffn = tf.keras.Sequential([
            tf.keras.layers.Dense(FF_DIM, activation='relu',
                                 kernel_regularizer=l2_reg),
            tf.keras.layers.Dropout(DROPOUT),
            tf.keras.layers.Dense(D_MODEL, kernel_regularizer=l2_reg),
            tf.keras.layers.Dropout(DROPOUT)
        ], name=f'ffn_{i}')

        ffn_out = ffn(x)
        x = tf.keras.layers.Add(name=f'add_ffn_{i}')([x, ffn_out])
        x = tf.keras.layers.LayerNormalization(epsilon=1e-6,
                                               name=f'ln_ffn_{i}')(x)

    # Collapse the time axis to a single vector per window.
    eeg_emb = tf.keras.layers.GlobalAveragePooling1D(name='eeg_pool')(x)
    eeg_emb = tf.keras.layers.Dropout(0.3)(eeg_emb)

    # ==================== BANDPOWER STREAM ====================
    bp_x = tf.keras.layers.Dense(BANDPOWER_HIDDEN_DIM, activation='relu',
                                  kernel_regularizer=l2_reg,
                                  name='bp_hidden')(bp_input)
    bp_x = tf.keras.layers.Dropout(0.3)(bp_x)
    bp_emb = tf.keras.layers.Dense(BANDPOWER_OUTPUT_DIM, activation='relu',
                                    kernel_regularizer=l2_reg,
                                    name='bp_output')(bp_x)
    bp_emb = tf.keras.layers.Dropout(0.3)(bp_emb)

    # ==================== TASK STREAM ====================
    task_emb = tf.keras.layers.Embedding(NUM_TASKS, TASK_EMBEDDING_DIM,
                                         embeddings_regularizer=l2_reg,
                                         name='task_emb')(task_input)
    task_emb = tf.keras.layers.Flatten(name='task_flatten')(task_emb)

    # ==================== FUSION ====================
    fused = tf.keras.layers.Concatenate(name='fusion')([eeg_emb, bp_emb, task_emb])

    # ==================== MULTI-TASK HEADS ====================
    # Each head sees the fused vector through its own dropout rate.

    # Channel head: strongest dropout and two hidden layers.
    print(f"   Building channel head: dropout=0.5, 3-layer deep architecture")
    fused_channel = tf.keras.layers.Dropout(0.5)(fused)
    channel_hidden = tf.keras.layers.Dense(512, activation='relu',
                                           kernel_regularizer=l2_reg)(fused_channel)
    channel_hidden = tf.keras.layers.Dropout(0.4)(channel_hidden)
    channel_hidden = tf.keras.layers.Dense(256, activation='relu',
                                           kernel_regularizer=l2_reg)(channel_hidden)
    channel_hidden = tf.keras.layers.Dropout(0.3)(channel_hidden)
    out_channel = tf.keras.layers.Dense(num_output_channels,
                                        kernel_regularizer=l2_reg,
                                        name='channel')(channel_hidden)

    # Region head: moderate dropout.
    fused_region = tf.keras.layers.Dropout(0.3)(fused)
    out_region = tf.keras.layers.Dense(NUM_OUTPUT_REGIONS,
                                       kernel_regularizer=l2_reg,
                                       name='region')(fused_region)

    # Band head: lower dropout.
    fused_band = tf.keras.layers.Dropout(0.2)(fused)
    out_band = tf.keras.layers.Dense(NUM_OUTPUT_BANDS,
                                     kernel_regularizer=l2_reg,
                                     name='band')(fused_band)

    # State head: minimal dropout.
    fused_state = tf.keras.layers.Dropout(0.1)(fused)
    out_state = tf.keras.layers.Dense(NUM_OUTPUT_STATES,
                                      kernel_regularizer=l2_reg,
                                      name='state')(fused_state)

    # ==================== MODEL ASSEMBLY ====================
    model = tf.keras.Model(
        inputs=[eeg_input, bp_input, task_input],
        outputs={
            'channel': out_channel,
            'region': out_region,
            'band': out_band,
            'state': out_state
        },
        name='CogniVue_Transformer_Corrected'
    )

    print(f"   ‚úÖ Model built successfully")
    print(f"   Total parameters: {model.count_params():,}")

    return model

print("‚úÖ Corrected model architecture defined")

##  5. Learning Rate Schedule

In [None]:

class WarmupCosineDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Linear warmup followed by cosine decay to zero.

    For ``step < warmup_steps`` the rate rises linearly from 0 to
    ``initial_learning_rate``. Afterwards it follows half a cosine period,
    reaching 0 at ``total_steps`` and staying there for any later step.

    Args:
        initial_learning_rate: Peak learning rate reached at the end of warmup.
        warmup_steps: Number of optimizer steps spent warming up.
        total_steps: Step at which the decay reaches zero.
    """
    def __init__(self, initial_learning_rate, warmup_steps, total_steps):
        super().__init__()
        self.initial_learning_rate = initial_learning_rate
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        warmup_steps = tf.cast(self.warmup_steps, tf.float32)
        total_steps = tf.cast(self.total_steps, tf.float32)
        peak_lr = tf.cast(self.initial_learning_rate, tf.float32)

        # Guard the divisors so warmup_steps == 0 or total_steps == warmup_steps
        # cannot produce NaN/inf.
        warmup_lr = peak_lr * step / tf.maximum(warmup_steps, 1.0)

        decay_steps = tf.maximum(total_steps - warmup_steps, 1.0)
        # Clamp progress to [0, 1]: without this the cosine turns upward again
        # once step exceeds total_steps and the learning rate starts rising.
        progress = tf.clip_by_value((step - warmup_steps) / decay_steps, 0.0, 1.0)
        decay_lr = peak_lr * 0.5 * (1.0 + tf.cos(np.pi * progress))

        return tf.where(step < warmup_steps, warmup_lr, decay_lr)

    def get_config(self):
        return {
            "initial_learning_rate": self.initial_learning_rate,
            "warmup_steps": self.warmup_steps,
            "total_steps": self.total_steps,
        }

print("‚úÖ Learning rate schedule defined")


##  6. Data Pipelines and Callbacks

In [None]:



def create_tf_dataset(X, y, is_train=True):
    """Wrap in-memory arrays as a batched, prefetched tf.data pipeline.

    Args:
        X: Sequence of three arrays fed to the 'eeg', 'bp' and 'task' inputs.
        y: Sequence of four label arrays for the 'channel', 'region', 'band'
            and 'state' outputs.
        is_train: When true, shuffle with a 10000-element buffer before batching.

    Returns:
        A ``tf.data.Dataset`` yielding ``(features, labels)`` dict pairs in
        batches of BATCH_SIZE; the final partial batch is dropped.
    """
    features = {'eeg': X[0], 'bp': X[1], 'task': X[2]}
    labels = {'channel': y[0], 'region': y[1], 'band': y[2], 'state': y[3]}

    pipeline = tf.data.Dataset.from_tensor_slices((features, labels))
    if is_train:
        pipeline = pipeline.shuffle(10000)

    return (
        pipeline
        .batch(BATCH_SIZE, drop_remainder=True)
        .prefetch(tf.data.AUTOTUNE)
    )


class PeriodicCheckpoint(tf.keras.callbacks.Callback):
    """Write a full-model checkpoint to CHECKPOINT_DIR every ``save_freq`` epochs."""
    def __init__(self, save_freq=2):
        super().__init__()
        self.save_freq = save_freq

    def on_epoch_end(self, epoch, logs=None):
        # `epoch` is zero-based; file names use the one-based epoch count.
        completed = epoch + 1
        if completed % self.save_freq != 0:
            return

        filepath = os.path.join(CHECKPOINT_DIR, f"checkpoint_epoch_{completed:03d}.keras")
        self.model.save(filepath)
        print(f"\n   üíæ Saved checkpoint: {os.path.basename(filepath)}")


class LearningRateLogger(tf.keras.callbacks.Callback):
    """Record the optimizer's current learning rate in the epoch logs."""
    def on_epoch_end(self, epoch, logs=None):
        optimizer = self.model.optimizer
        current = optimizer.learning_rate
        # A schedule is callable and must be evaluated at the current step;
        # anything else is used as the learning-rate value directly.
        if hasattr(current, '__call__'):
            current = current(optimizer.iterations)

        lr_float = float(tf.keras.backend.get_value(current))
        if logs is None:
            return
        logs['learning_rate'] = lr_float


def create_callbacks(*, include_reduce_lr_on_plateau=False):
    """
    Build the list of Keras callbacks used for training.

    Always included, in this order:
    - EarlyStopping on val_loss with best-weight restoration; it does not
      start counting until WARMUP_EPOCHS have passed.
    - ModelCheckpoint writing the best full model (by val_loss) to
      CHECKPOINT_DIR/best_model.keras.
    - PeriodicCheckpoint every SAVE_CHECKPOINT_EVERY epochs.
    - LearningRateLogger.
    - TensorBoard logging to LOGS_DIR.

    Args:
        include_reduce_lr_on_plateau: Also append ReduceLROnPlateau. Off by
            default because that callback works by assigning a new value to
            the optimizer's learning rate, which is not possible when the
            optimizer was built with a LearningRateSchedule (as this notebook
            does); it would raise the first time it tried to lower the rate.
            Enable it only with a plain float learning rate.

    Returns:
        list of ``tf.keras.callbacks.Callback`` instances.
    """
    callbacks = [
        # Stop when val_loss stalls and roll back to the best weights seen.
        tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=EARLY_STOPPING_PATIENCE,
            restore_best_weights=True,
            verbose=1,
            mode='min',
            start_from_epoch=WARMUP_EPOCHS  # Don't stop during warmup
        ),

        # Keep the single best full model on disk.
        tf.keras.callbacks.ModelCheckpoint(
            filepath=os.path.join(CHECKPOINT_DIR, 'best_model.keras'),
            monitor='val_loss',
            save_best_only=True,
            mode='min',
            verbose=1,
            save_weights_only=False
        ),

        # Periodic checkpoints regardless of validation performance.
        PeriodicCheckpoint(save_freq=SAVE_CHECKPOINT_EVERY),

        # Adds 'learning_rate' to the epoch logs (placed before TensorBoard).
        LearningRateLogger(),

        tf.keras.callbacks.TensorBoard(
            log_dir=LOGS_DIR,
            histogram_freq=1,
            write_graph=True,
            update_freq='epoch',
            profile_batch=0  # Disable profiling to save memory
        ),
    ]

    if include_reduce_lr_on_plateau:
        callbacks.append(
            tf.keras.callbacks.ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.5,
                patience=5,
                min_lr=1e-7,
                verbose=1
            )
        )

    return callbacks

print("‚úÖ Data pipeline & corrected callbacks defined")

##  7. Load Data

In [None]:



# ------------------------------------------------------------------
# CELL 7: LOAD DATA
# ------------------------------------------------------------------

print("\n" + "="*80)
print("üìÇ LOADING DATA")
print("="*80)

# Training split: the loader returns None on any failure, which is fatal here.
train_result = load_preprocessed_data('train')
if train_result is None:
    raise ValueError("Failed to load training data!")
train_X, train_y, train_metadata = train_result

# Validation split, same contract.
val_result = load_preprocessed_data('val')
if val_result is None:
    raise ValueError("Failed to load validation data!")
val_X, val_y, val_metadata = val_result

# Both splits must share input dimensions or the model cannot consume them.
for _meta_key, _mismatch_message in (
    ('num_channels', "Train and val have different channel counts!"),
    ('bandpower_dim', "Train and val have different bandpower dimensions!"),
):
    assert train_metadata[_meta_key] == val_metadata[_meta_key], _mismatch_message

print(f"\n‚úÖ Data loaded successfully:")
print(f"   Train samples: {train_metadata['num_samples']:,}")
print(f"   Val samples: {val_metadata['num_samples']:,}")
print(f"   Channels: {train_metadata['num_channels']}")
print(f"   Output classes (channel): {train_metadata['num_output_channels']}")


##  8. Create Datasets

In [None]:


# ------------------------------------------------------------------
# CELL 8: CREATE DATASETS
# ------------------------------------------------------------------

print("\n" + "="*80)
print("üîÑ CREATING TF DATASETS")
print("="*80)

# Whole batches available per pass over each split (partial batch excluded).
steps_per_epoch, val_steps = (
    len(split_inputs[0]) // BATCH_SIZE for split_inputs in (train_X, val_X)
)

print(f"\nüìä Dataset info:")
print(f"   Batch size: {BATCH_SIZE}")
print(f"   Steps per epoch: {steps_per_epoch}")
print(f"   Validation steps: {val_steps}")
print(f"   Effective batch size (with {strategy.num_replicas_in_sync} GPUs): {BATCH_SIZE * strategy.num_replicas_in_sync}")

# Only the training pipeline is shuffled.
train_dataset = create_tf_dataset(train_X, train_y, is_train=True)
val_dataset = create_tf_dataset(val_X, val_y, is_train=False)

print("‚úÖ Datasets created")

##  9. Build and Compile Model

In [None]:



print("\n" + "="*80)
print("üèóÔ∏è  BUILDING & COMPILING MODEL")
print("="*80)


class SmoothedSparseCategoricalCrossentropy(tf.keras.losses.Loss):
    """Cross-entropy on integer labels with optional label smoothing.

    ``tf.keras.losses.SparseCategoricalCrossentropy`` has no
    ``label_smoothing`` argument (only the dense ``CategoricalCrossentropy``
    does), so passing it raises ``TypeError``. This loss one-hot encodes the
    integer labels and delegates to the dense cross-entropy, which supports
    smoothing.

    NOTE: this is a custom loss. Loading a saved, compiled model requires
    ``custom_objects={'SmoothedSparseCategoricalCrossentropy':
    SmoothedSparseCategoricalCrossentropy}`` or ``compile=False``.

    Args:
        num_classes: Number of classes (size of the logits' last axis).
        label_smoothing: Smoothing factor in [0, 1); 0 disables smoothing.
        from_logits: Whether predictions are raw logits.
        name: Loss name.
        **kwargs: Forwarded to ``tf.keras.losses.Loss`` (e.g. ``reduction``).
    """

    def __init__(self, num_classes, label_smoothing=0.0, from_logits=True,
                 name='smoothed_sparse_categorical_crossentropy', **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_classes = int(num_classes)
        self.label_smoothing = float(label_smoothing)
        self.from_logits = from_logits

    def call(self, y_true, y_pred):
        # Labels may arrive as (batch,) or (batch, 1) and possibly as floats.
        labels = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
        one_hot = tf.one_hot(labels, depth=self.num_classes, dtype=y_pred.dtype)
        return tf.keras.losses.categorical_crossentropy(
            one_hot,
            y_pred,
            from_logits=self.from_logits,
            label_smoothing=self.label_smoothing
        )

    def get_config(self):
        config = super().get_config()
        config.update({
            'num_classes': self.num_classes,
            'label_smoothing': self.label_smoothing,
            'from_logits': self.from_logits,
        })
        return config


# Variables (model, optimizer) must be created inside the strategy scope so
# they are placed/mirrored correctly on multi-device setups.
with strategy.scope():
    model = create_model(
        num_channels=train_metadata['num_channels'],
        bandpower_input_dim=train_metadata['bandpower_dim'],
        num_output_channels=train_metadata['num_output_channels']
    )

    # Learning rate schedule, expressed in optimizer steps.
    total_steps = EPOCHS * steps_per_epoch
    warmup_steps = WARMUP_EPOCHS * steps_per_epoch

    lr_schedule = WarmupCosineDecay(
        initial_learning_rate=INITIAL_LR,
        warmup_steps=warmup_steps,
        total_steps=total_steps
    )

    print(f"\nüìà Learning rate schedule:")
    print(f"   Initial LR: {INITIAL_LR}")
    print(f"   Warmup steps: {warmup_steps} ({WARMUP_EPOCHS} epochs)")
    print(f"   Total steps: {total_steps} ({EPOCHS} epochs)")

    optimizer = tf.keras.optimizers.AdamW(
        learning_rate=lr_schedule,
        weight_decay=WEIGHT_DECAY,
        clipnorm=GRADIENT_CLIP_NORM
    )

    # Per-head losses on logits. Smoothing is heavier on the harder heads
    # and disabled for 'state'.
    losses = {
        'channel': SmoothedSparseCategoricalCrossentropy(
            num_classes=train_metadata['num_output_channels'],
            label_smoothing=0.1
        ),
        'region': SmoothedSparseCategoricalCrossentropy(
            num_classes=NUM_OUTPUT_REGIONS,
            label_smoothing=0.1
        ),
        'band': SmoothedSparseCategoricalCrossentropy(
            num_classes=NUM_OUTPUT_BANDS,
            label_smoothing=0.05
        ),
        # No smoothing needed, so the built-in sparse loss is sufficient.
        'state': tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True
        )
    }

    # Relative contribution of each head to the total loss.
    loss_weights = {
        'channel': 3.0,
        'region': 1.5,
        'band': 1.0,
        'state': 0.5
    }

    print(f"\n‚öñÔ∏è  Loss configuration:")
    print(f"   Label smoothing: channel=0.1, region=0.1, band=0.05, state=0.0")
    print(f"   ‚úÖ Loss weights: channel=3.0, region=1.5, band=1.0, state=0.5")

    model.compile(
        optimizer=optimizer,
        loss=losses,
        loss_weights=loss_weights,
        metrics={
            'channel': 'accuracy',
            'region': 'accuracy',
            'band': 'accuracy',
            'state': 'accuracy'
        }
    )

print(f"\n‚úÖ Model compiled successfully")
print(f"   Optimizer: AdamW (LR={INITIAL_LR}, weight_decay={WEIGHT_DECAY})")
print(f"   Gradient clipping: {GRADIENT_CLIP_NORM}")

print(f"\nüìã Model summary:")
model.summary()


##  10. Training

In [None]:


print("\n" + "="*80)
print("üöÄ STARTING TRAINING")
print("="*80)

print(f"\nüéØ Training configuration:")
print(f"   Epochs: {EPOCHS}")
print(f"   Batch size: {BATCH_SIZE}")
print(f"   Effective batch size: {BATCH_SIZE * strategy.num_replicas_in_sync}")
print(f"   Steps per epoch: {steps_per_epoch}")
print(f"   Total steps: {total_steps}")
print(f"   ‚úÖ Early stopping patience: {EARLY_STOPPING_PATIENCE} epochs")
print(f"   Strategy: {strategy.__class__.__name__} ({strategy.num_replicas_in_sync} replicas)")

callbacks = create_callbacks()

print(f"\nüìû Active callbacks:")
for cb in callbacks:
    print(f"   - {cb.__class__.__name__}")

# Wall-clock start, used for the duration report here and in later cells.
start_time = time.time()
training_start = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

print(f"\n‚è∞ Training started at: {training_start}")
print("="*80)

# Bind `history` up front: if fit() is interrupted it never returns, and the
# result-saving cell would otherwise fail with NameError.
history = None

try:
    history = model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=EPOCHS,
        callbacks=callbacks,
        verbose=1
    )

    # Break the elapsed time into h/m/s for display.
    training_duration = time.time() - start_time
    hours = int(training_duration // 3600)
    minutes = int((training_duration % 3600) // 60)
    seconds = int(training_duration % 60)

    print("\n" + "="*80)
    print("‚úÖ TRAINING COMPLETED!")
    print("="*80)
    print(f"‚è±Ô∏è  Total training time: {hours}h {minutes}m {seconds}s")
    print(f"   Completed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    # Metrics of the last epoch that ran (train / val).
    final_epoch = len(history.history['loss'])
    print(f"\nüìä Final metrics (epoch {final_epoch}):")
    print(f"   Train loss: {history.history['loss'][-1]:.4f}")
    print(f"   Val loss: {history.history['val_loss'][-1]:.4f}")
    print(f"   Channel acc: {history.history['channel_accuracy'][-1]:.4f} / {history.history['val_channel_accuracy'][-1]:.4f}")
    print(f"   Region acc: {history.history['region_accuracy'][-1]:.4f} / {history.history['val_region_accuracy'][-1]:.4f}")
    print(f"   Band acc: {history.history['band_accuracy'][-1]:.4f} / {history.history['val_band_accuracy'][-1]:.4f}")
    print(f"   State acc: {history.history['state_accuracy'][-1]:.4f} / {history.history['val_state_accuracy'][-1]:.4f}")

    # Fewer recorded epochs than requested means training stopped early.
    if final_epoch < EPOCHS:
        print(f"\n‚ö†Ô∏è  Early stopping triggered at epoch {final_epoch}")
        print(f"   Best model restored (this is GOOD!)")

except KeyboardInterrupt:
    print("\n‚ö†Ô∏è  Training interrupted by user")
    training_duration = time.time() - start_time
    print(f"   Training time before interruption: {int(training_duration)}s")

    # Keras exposes the History callback as `model.history` once at least one
    # epoch has finished; recover it so completed epochs can still be saved.
    history = getattr(model, 'history', None)
    if history is None or not history.history:
        print("   No completed epochs were recorded; there is no history to save")

except Exception as e:
    print(f"\n‚ùå Training failed with error: {e}")
    import traceback
    traceback.print_exc()
    raise

##  11. Save Final Model and Results

In [None]:

print("\n" + "="*80)
print("üíæ SAVING RESULTS")
print("="*80)

# Persist the model as it stands at the end of training.
final_model_path = os.path.join(CHECKPOINT_DIR, 'final_model.keras')
model.save(final_model_path)
print(f"‚úÖ Final model saved: {final_model_path}")

# Convert every logged value to a plain float so the history is JSON-serializable.
history_dict = {}
for _metric_name, _metric_values in history.history.items():
    history_dict[_metric_name] = [float(_value) for _value in _metric_values]

history_path = os.path.join(RESULTS_DIR, 'training_history.json')
with open(history_path, 'w') as f:
    json.dump(history_dict, f, indent=2)
print(f"‚úÖ Training history saved: {history_path}")

# Assemble a JSON-serializable record of the run: architecture, training
# parameters, head sizes, loss weights, last-epoch metrics and run info.

# Last-epoch train/val accuracy for each head, in a fixed head order.
_final_metrics = {
    'best_val_loss': float(min(history.history['val_loss'])),
    'final_train_loss': float(history.history['loss'][-1]),
    'final_val_loss': float(history.history['val_loss'][-1]),
}
for _task in ('channel', 'region', 'band', 'state'):
    _final_metrics[f'final_{_task}_accuracy'] = float(
        history.history[f'{_task}_accuracy'][-1])
    _final_metrics[f'final_val_{_task}_accuracy'] = float(
        history.history[f'val_{_task}_accuracy'][-1])

config = {
    'model_architecture': {
        'name': 'CogniVue_Transformer_Corrected',
        'num_input_channels': train_metadata['num_channels'],
        'd_model': D_MODEL,
        'num_layers': NUM_LAYERS,
        'num_heads': NUM_HEADS,
        'ff_dim': FF_DIM,
        'dropout': DROPOUT,
        'window_size': WINDOW_SIZE_SAMPLES,
        'l2_regularization': L2_REGULARIZATION
    },
    'training_params': {
        'epochs_trained': len(history.history['loss']),
        'total_epochs': EPOCHS,
        'batch_size': BATCH_SIZE,
        'effective_batch_size': BATCH_SIZE * strategy.num_replicas_in_sync,
        'initial_lr': float(INITIAL_LR),
        'warmup_epochs': WARMUP_EPOCHS,
        'weight_decay': WEIGHT_DECAY,
        'gradient_clip_norm': GRADIENT_CLIP_NORM,
        'early_stopping_patience': EARLY_STOPPING_PATIENCE
    },
    'output_tasks': {
        'num_output_channels': train_metadata['num_output_channels'],
        'num_regions': NUM_OUTPUT_REGIONS,
        'num_bands': NUM_OUTPUT_BANDS,
        'num_states': NUM_OUTPUT_STATES
    },
    # Record the weights the model was actually compiled with instead of a
    # second hard-coded copy that could drift out of sync.
    'loss_weights': {task: float(weight) for task, weight in loss_weights.items()},
    'final_metrics': _final_metrics,
    'training_info': {
        'started_at': training_start,
        'completed_at': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        'duration_seconds': int(time.time() - start_time),
        'tensorflow_version': tf.__version__,
        'strategy': strategy.__class__.__name__,
        # Replica count of the strategy (reported under the key 'num_gpus').
        'num_gpus': strategy.num_replicas_in_sync
    }
}

config_path = os.path.join(RESULTS_DIR, 'training_config.json')
with open(config_path, 'w') as f:
    json.dump(config, f, indent=2)
print(f"‚úÖ Configuration saved: {config_path}")

# Render a human-readable Markdown report of the run. All values are
# interpolated from the hyperparameter constants, `history`, `strategy`
# and the `config` dict assembled above.
summary_md = f"""# CogniVue Training Summary (‚úÖ CORRECTED)

**Training Completed:** {config['training_info']['completed_at']}

## Model Architecture (‚úÖ Corrected)
- **Model:** CogniVue Transformer (Corrected)
- **Input Channels:** {train_metadata['num_channels']}
- **Model Dimension:** {D_MODEL} ‚úÖ (reduced from 256)
- **Transformer Layers:** {NUM_LAYERS} ‚úÖ (reduced from 6)
- **Attention Heads:** {NUM_HEADS}
- **Feedforward Dim:** {FF_DIM} ‚úÖ (reduced from 1024)
- **Dropout:** {DROPOUT} ‚úÖ (increased from 0.15)
- **L2 Regularization:** {L2_REGULARIZATION} ‚úÖ (NEW)

## Training Configuration (‚úÖ Corrected)
- **Epochs:** {len(history.history['loss'])}/{EPOCHS}
- **Batch Size:** {BATCH_SIZE} ‚úÖ (reduced from 64)
- **Effective Batch Size:** {BATCH_SIZE * strategy.num_replicas_in_sync} ({strategy.num_replicas_in_sync} GPUs)
- **Initial LR:** {INITIAL_LR} ‚úÖ (reduced from 1e-4)
- **Warmup Epochs:** {WARMUP_EPOCHS}
- **Weight Decay:** {WEIGHT_DECAY}
- **Early Stopping Patience:** {EARLY_STOPPING_PATIENCE} ‚úÖ (NEW)

## Loss Configuration (‚úÖ Corrected)
- **Loss Weights:** channel=3.0, region=1.5, band=1.0, state=0.5 ‚úÖ (NEW)
- **Label Smoothing:** channel=0.1, region=0.1, band=0.05, state=0.0 ‚úÖ (NEW)

## Final Performance
- **Best Val Loss:** {config['final_metrics']['best_val_loss']:.4f}
- **Final Train Loss:** {config['final_metrics']['final_train_loss']:.4f}
- **Final Val Loss:** {config['final_metrics']['final_val_loss']:.4f}

### Task-Specific Accuracy
- **Channel:** {config['final_metrics']['final_channel_accuracy']:.4f} / {config['final_metrics']['final_val_channel_accuracy']:.4f}
- **Region:** {config['final_metrics']['final_region_accuracy']:.4f} / {config['final_metrics']['final_val_region_accuracy']:.4f}
- **Band:** {config['final_metrics']['final_band_accuracy']:.4f} / {config['final_metrics']['final_val_band_accuracy']:.4f}
- **State:** {config['final_metrics']['final_state_accuracy']:.4f} / {config['final_metrics']['final_val_state_accuracy']:.4f}

## Training Info
- **Duration:** {config['training_info']['duration_seconds']} seconds
- **Strategy:** {config['training_info']['strategy']}
- **GPUs Used:** {config['training_info']['num_gpus']}
- **TensorFlow:** {config['training_info']['tensorflow_version']}

## Output Files
- `checkpoints/best_model.keras` - Best model weights (restored by early stopping)
- `checkpoints/final_model.keras` - Final model weights
- `checkpoints/checkpoint_epoch_*.keras` - Periodic checkpoints
- `results/training_history.json` - Loss and metrics per epoch
- `results/training_config.json` - Full configuration
- `logs/` - TensorBoard logs
"""

# NOTE(review): the loss weights and label-smoothing values in the report are
# literal text, not read from the compiled model; update them by hand if the
# compile cell changes.
# Write the report next to the JSON results.
summary_path = os.path.join(RESULTS_DIR, 'TRAINING_SUMMARY.md')
with open(summary_path, 'w') as f:
    f.write(summary_md)
print(f"‚úÖ Summary saved: {summary_path}")

print("\n" + "="*80)
print("‚úÖ ALL RESULTS SAVED")
print("="*80)

##  12. Package and Download All Outputs

In [None]:

# ============================================================
# Section 11: Package & Download All Outputs
# ============================================================
# Bundles the files in RESULTS_DIR and CHECKPOINT_DIR plus everything under
# LOGS_DIR into one timestamped zip placed directly in WORKING_DIR.

import zipfile
from pathlib import Path

print("\n Creating download package...")
print("=" * 70)

# Create zip filename with timestamp
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
zip_filename = f"cognivue_training_outputs_{timestamp}.zip"
zip_path = os.path.join(WORKING_DIR, zip_filename)

# Flipped to True only once the archive has been fully written and closed, so
# the error handler can tell a truncated archive from a finished one.
archive_complete = False

try:
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        
        # Add all files from results directory (top level only, sorted so the
        # archive order is reproducible)
        print("\n Adding results...")
        if os.path.exists(RESULTS_DIR):
            for file in sorted(os.listdir(RESULTS_DIR)):
                file_path = os.path.join(RESULTS_DIR, file)
                if os.path.isfile(file_path):
                    arcname = os.path.join('results', file)
                    zipf.write(file_path, arcname)
                    print(f"   {file}")
        
        # Add all checkpoint files (top level only)
        print("\nüîñ Adding checkpoints...")
        if os.path.exists(CHECKPOINT_DIR):
            for file in sorted(os.listdir(CHECKPOINT_DIR)):
                file_path = os.path.join(CHECKPOINT_DIR, file)
                if os.path.isfile(file_path):
                    arcname = os.path.join('checkpoints', file)
                    zipf.write(file_path, arcname)
                    file_size_mb = os.path.getsize(file_path) / (1024 * 1024)
                    print(f"   {file} ({file_size_mb:.1f} MB)")
        
        # Add TensorBoard logs (optional - can be large); walked recursively,
        # keeping each file's path relative to LOGS_DIR inside the archive
        print("\n Adding TensorBoard logs...")
        if os.path.exists(LOGS_DIR):
            log_count = 0
            for root, _dirs, files in os.walk(LOGS_DIR):
                for file in sorted(files):
                    file_path = os.path.join(root, file)
                    arcname = os.path.join('logs', os.path.relpath(file_path, LOGS_DIR))
                    zipf.write(file_path, arcname)
                    log_count += 1
            print(f"   Added {log_count} log files")
    
    # The `with` block closed cleanly, so the archive on disk is complete.
    archive_complete = True
    
    # Get final zip size
    zip_size_mb = os.path.getsize(zip_path) / (1024 * 1024)
    
    print("\n" + "=" * 70)
    print(" PACKAGE CREATED SUCCESSFULLY!")
    print("=" * 70)
    print(f"\n Zip file: {zip_filename}")
    print(f" Size: {zip_size_mb:.1f} MB")
    print(f" Location: {zip_path}")
    
    print("\n To download:")
    print("   1. Go to the 'Output' tab (top right)")
    print("   2. Click 'Save Version' to commit outputs")
    print(f"   3. Download '{zip_filename}'")
    print("\n Or click the download icon next to the file in the Output tab")
    
    # List contents: the total is read back from the archive; the tree below
    # is a static illustration of the expected layout, not a live listing.
    print("\n Package contents:")
    with zipfile.ZipFile(zip_path, 'r') as zipf:
        file_list = zipf.namelist()
        print(f"   Total files: {len(file_list)}")
        print("\n   Structure:")
        print("   ‚îú‚îÄ‚îÄ results/")
        print("   ‚îÇ   ‚îú‚îÄ‚îÄ training_history.json")
        print("   ‚îÇ   ‚îú‚îÄ‚îÄ training_config.json")
        print("   ‚îÇ   ‚îî‚îÄ‚îÄ TRAINING_SUMMARY.md")
        print("   ‚îú‚îÄ‚îÄ checkpoints/")
        print("   ‚îÇ   ‚îú‚îÄ‚îÄ best_model.keras")
        print("   ‚îÇ   ‚îú‚îÄ‚îÄ final_model.keras")
        print("   ‚îÇ   ‚îî‚îÄ‚îÄ checkpoint_epoch_*.keras")
        print("   ‚îî‚îÄ‚îÄ logs/")
        print("       ‚îî‚îÄ‚îÄ TensorBoard logs")
    
    print("\n" + "=" * 70)
    
except Exception as e:
    print(f"\n‚ùå Error creating zip: {e}")
    import traceback
    traceback.print_exc()
    # If writing was interrupted, the file on disk is a truncated archive.
    # Remove it so it neither wastes output space nor gets downloaded by
    # mistake. A finished archive (failure during reporting) is kept.
    if not archive_complete and os.path.exists(zip_path):
        try:
            os.remove(zip_path)
            print(f"   Removed incomplete archive: {zip_path}")
        except OSError as cleanup_error:
            print(f"   Could not remove incomplete archive: {cleanup_error}")

##  12. Create Download Package

In [None]:



# ============================================================
# CELL 12: CREATE DOWNLOAD PACKAGE
# ============================================================
# Writes results/, checkpoints/ and logs/ into a timestamped
# `cognivue_corrected_outputs_*.zip` inside WORKING_DIR.

import zipfile

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
zip_filename = f"cognivue_corrected_outputs_{timestamp}.zip"
zip_path = os.path.join(WORKING_DIR, zip_filename)

print("\n" + "="*80)
print("üì¶ CREATING DOWNLOAD PACKAGE")
print("="*80)

# Becomes True only after the archive has been fully written and closed; the
# error handler uses it to decide whether the file on disk is trustworthy.
archive_complete = False

try:
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        
        # Add results (top-level files only, sorted for a reproducible order)
        print("\nüìÑ Adding results...")
        if os.path.exists(RESULTS_DIR):
            for file in sorted(os.listdir(RESULTS_DIR)):
                file_path = os.path.join(RESULTS_DIR, file)
                if os.path.isfile(file_path):
                    arcname = os.path.join('results', file)
                    zipf.write(file_path, arcname)
                    print(f"   ‚úÖ {file}")
        
        # Add checkpoints (top-level files only)
        print("\nüîñ Adding checkpoints...")
        if os.path.exists(CHECKPOINT_DIR):
            for file in sorted(os.listdir(CHECKPOINT_DIR)):
                file_path = os.path.join(CHECKPOINT_DIR, file)
                if os.path.isfile(file_path):
                    arcname = os.path.join('checkpoints', file)
                    zipf.write(file_path, arcname)
                    file_size_mb = os.path.getsize(file_path) / (1024 * 1024)
                    print(f"   ‚úÖ {file} ({file_size_mb:.1f} MB)")
        
        # Add TensorBoard logs, recursively, preserving sub-directory layout
        print("\nüìä Adding TensorBoard logs...")
        if os.path.exists(LOGS_DIR):
            log_count = 0
            for root, _dirs, files in os.walk(LOGS_DIR):
                for file in sorted(files):
                    file_path = os.path.join(root, file)
                    arcname = os.path.join('logs', os.path.relpath(file_path, LOGS_DIR))
                    zipf.write(file_path, arcname)
                    log_count += 1
            print(f"   ‚úÖ Added {log_count} log files")
    
    # Reaching this point means the archive was closed without error.
    archive_complete = True
    
    zip_size_mb = os.path.getsize(zip_path) / (1024 * 1024)
    
    print("\n" + "="*80)
    print("‚úÖ PACKAGE CREATED SUCCESSFULLY!")
    print("="*80)
    print(f"\nüì¶ Zip file: {zip_filename}")
    print(f"   Size: {zip_size_mb:.1f} MB")
    print(f"   Location: {zip_path}")
    
    print("\nüì• To download:")
    print("   1. Go to the 'Output' tab (top right)")
    print("   2. Click 'Save Version' to commit outputs")
    print(f"   3. Download '{zip_filename}'")
    
except Exception as e:
    print(f"\n‚ùå Error creating zip: {e}")
    import traceback
    traceback.print_exc()
    # A failure while writing leaves a truncated archive whose name still
    # matches `cognivue_corrected_outputs_*.zip`; delete it so it cannot be
    # picked up and offered for download as if it were a valid package.
    if not archive_complete and os.path.exists(zip_path):
        try:
            os.remove(zip_path)
            print(f"   Removed incomplete archive: {zip_path}")
        except OSError as cleanup_error:
            print(f"   Could not remove incomplete archive: {cleanup_error}")


##  13. Download Link


In [None]:


# ============================================================
# CELL 13: DISPLAY DOWNLOAD LINK
# ============================================================
# Finds the most recently written output archive in WORKING_DIR and renders a
# clickable download link for it.

from IPython.display import FileLink, display

print("\n" + "="*80)
print("üéâ TRAINING COMPLETE - READY FOR DOWNLOAD")
print("="*80)

# Find the latest zip file. Both packaging cells are considered: the
# "corrected" archive and the "training" archive written by Section 11, so a
# package is still offered when only one of the two cells was run.
import glob
zip_files = []
for pattern in ('cognivue_corrected_outputs_*.zip', 'cognivue_training_outputs_*.zip'):
    zip_files.extend(glob.glob(os.path.join(WORKING_DIR, pattern)))

if zip_files:
    # Use modification time: on Unix, getctime is the last metadata change
    # (chmod/rename), not creation, so it can pick the wrong "latest" file.
    latest_zip = max(zip_files, key=os.path.getmtime)
    # Express the path relative to the current working directory so the link
    # does not silently assume the notebook runs from WORKING_DIR.
    # NOTE(review): FileLink is expected to resolve paths relative to the
    # notebook directory - confirm against the IPython documentation.
    relative_path = os.path.relpath(latest_zip)
    
    print(f"\n‚úÖ Download package ready:")
    print(f"   {relative_path}")
    print("\nüì• Click link below to download:")
    display(FileLink(relative_path))
    
    # Static reminder of the targets for this run; these figures are
    # hard-coded text, not values computed from the training history.
    print("\n" + "="*80)
    print("EXPECTED IMPROVEMENTS vs PREVIOUS TRAINING:")
    print("="*80)
    print("‚úÖ Validation loss: Should be ~0.65-0.75 (was 1.22)")
    print("‚úÖ Channel accuracy: Should be ~68-75% (was 56%)")
    print("‚úÖ Training stops: Around epoch 20-25 (was 26+)")
    print("‚úÖ Overfitting: Reduced by 70-80%")
    print("\nAll corrections have been applied!")
    
else:
    print("‚ùå No output zip files found")

