In [3]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import os
import glob
import nibabel as nib
from sklearn.model_selection import train_test_split
from scipy.ndimage import zoom

# ---------------------------------------------------------------------------
# Process-wide TensorFlow configuration (runs once when the cell/module loads)
# ---------------------------------------------------------------------------

# Switch the XLA JIT off; the author's stated goal is a faster startup.
tf.config.optimizer.set_jit(False)

# Request on-demand GPU memory growth for every visible GPU.
gpus = tf.config.experimental.list_physical_devices('GPU')
try:
    # With no GPU present the loop body never executes, so nothing happens.
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as growth_error:
    # Best effort: report the failure and carry on with default allocation.
    print(f"GPU memory growth setting failed: {growth_error}")

# Global Keras dtype policy; layers created afterwards pick up 'mixed_float16'.
tf.keras.mixed_precision.set_global_policy('mixed_float16')

class OptimizedSEBlock3D(layers.Layer):
    """Lightweight 3D Squeeze-and-Excitation (channel attention) block.

    The input is averaged over its three spatial axes, passed through a
    two-layer dense bottleneck ending in a sigmoid, and the resulting
    per-channel gates are multiplied back onto the input.

    Args:
        channels: Size of the last (channel) axis of the tensors this layer
            will receive. ``call`` reshapes with this value, so it must match
            the actual input.
        reduction: Divisor applied to ``channels`` to size the bottleneck.
            The bottleneck never drops below 4 units.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, channels, reduction=16, **kwargs):
        super(OptimizedSEBlock3D, self).__init__(**kwargs)
        self.channels = channels
        self.reduction = reduction

        # Floor the bottleneck at 4 units so narrow feature maps
        # (e.g. 16 channels with reduction=16) do not collapse to 1 unit.
        reduced_channels = max(channels // reduction, 4)

        # Global average pooling over the three spatial axes
        self.global_avg_pool = layers.GlobalAveragePooling3D(keepdims=True)

        # Bottleneck (ReLU) followed by per-channel gates in (0, 1) (sigmoid)
        self.squeeze = layers.Dense(reduced_channels, activation='relu')
        self.excitation = layers.Dense(channels, activation='sigmoid')

    def call(self, inputs):
        """Return ``inputs`` scaled channel-wise by the learned gates."""
        # Squeeze: one scalar per channel, flattened to (batch, channels)
        squeezed = self.global_avg_pool(inputs)
        squeezed = tf.reshape(squeezed, [-1, self.channels])

        # Excitation: dense bottleneck -> per-channel gate
        excited = self.squeeze(squeezed)
        excited = self.excitation(excited)
        # Restore singleton spatial axes so the gates broadcast over the volume
        excited = tf.reshape(excited, [-1, 1, 1, 1, self.channels])

        # Scale original input
        return inputs * excited

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(OptimizedSEBlock3D, self).get_config()
        config.update({
            'channels': self.channels,
            'reduction': self.reduction,
        })
        return config


class LightweightParallelConvBlock(layers.Layer):
    """Two-branch (1x1x1 and 3x3x3) convolution block with 2x max pooling.

    Both branches read the same input; their outputs are concatenated on the
    channel axis, then batch-normalised, passed through LeakyReLU and dropout.

    Args:
        filters: Total number of output channels after concatenation. The
            1x1x1 branch gets ``filters // 2`` channels and the 3x3x3 branch
            gets the remainder, so odd values still produce exactly
            ``filters`` channels.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, filters, **kwargs):
        super(LightweightParallelConvBlock, self).__init__(**kwargs)
        self.filters = filters

        # Split the channel budget; for even `filters` both halves are equal.
        filters_1x1 = filters // 2
        filters_3x3 = filters - filters_1x1

        # Use only 1x1 and 3x3 convolutions (no 5x5 branch, for speed)
        self.conv_1x1 = layers.Conv3D(
            filters_1x1,
            kernel_size=(1, 1, 1),
            padding='same',
            kernel_regularizer=keras.regularizers.l2(0.01)
        )
        self.conv_3x3 = layers.Conv3D(
            filters_3x3,
            kernel_size=(3, 3, 3),
            padding='same',
            kernel_regularizer=keras.regularizers.l2(0.01)
        )

        # Batch normalization instead of layer normalization for speed
        self.batch_norm = layers.BatchNormalization()
        self.leaky_relu = layers.LeakyReLU(alpha=0.1)
        self.dropout = layers.Dropout(rate=0.1)  # Reduced dropout
        self.max_pool = layers.MaxPooling3D(pool_size=(2, 2, 2))

    def call(self, inputs, training=None):
        """Run the block.

        Args:
            inputs: 5-D feature tensor (channels last).
            training: Forwarded to batch normalisation and dropout.

        Returns:
            Tuple ``(features, pooled)``: the full-resolution features (used
            as a skip connection) and the same features after 2x max pooling.
        """
        # Parallel branches (only 1x1 and 3x3)
        branch_1x1 = self.conv_1x1(inputs)
        branch_3x3 = self.conv_3x3(inputs)

        # Concatenate branches on the channel axis
        concat = layers.concatenate([branch_1x1, branch_3x3], axis=-1)

        # Apply normalization and activation
        normalized = self.batch_norm(concat, training=training)
        activated = self.leaky_relu(normalized)
        dropped = self.dropout(activated, training=training)

        # Max pooling for encoder
        pooled = self.max_pool(dropped)

        return dropped, pooled

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(LightweightParallelConvBlock, self).get_config()
        config.update({'filters': self.filters})
        return config


class OptimizedEncoderBlock(layers.Layer):
    """Encoder level: one 3x3x3 convolution, SE attention and 2x max pooling.

    Pipeline: Conv3D -> BatchNormalization -> LeakyReLU -> Dropout ->
    OptimizedSEBlock3D; the result is returned both as-is (skip connection)
    and after max pooling.

    Args:
        filters: Number of output channels of the convolution (and of the
            SE block).
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, filters, **kwargs):
        super(OptimizedEncoderBlock, self).__init__(**kwargs)
        self.filters = filters

        # Single convolution instead of double for speed
        self.conv = layers.Conv3D(
            filters,
            kernel_size=(3, 3, 3),
            padding='same',
            kernel_regularizer=keras.regularizers.l2(0.01)
        )

        # Batch normalization for speed
        self.batch_norm = layers.BatchNormalization()
        self.leaky_relu = layers.LeakyReLU(alpha=0.1)
        self.dropout = layers.Dropout(rate=0.1)

        # Lightweight SE attention
        self.se_block = OptimizedSEBlock3D(filters)

        # Max pooling
        self.max_pool = layers.MaxPooling3D(pool_size=(2, 2, 2))

    def call(self, inputs, training=None):
        """Run the block.

        Args:
            inputs: 5-D feature tensor (channels last).
            training: Forwarded to batch normalisation and dropout.

        Returns:
            Tuple ``(skip, pooled)``: the attended features before pooling
            and the same features after 2x max pooling.
        """
        # Single convolution
        x = self.conv(inputs)
        x = self.batch_norm(x, training=training)
        x = self.leaky_relu(x)
        x = self.dropout(x, training=training)

        # SE attention
        x = self.se_block(x)

        # Store skip connection before pooling
        skip = x

        # Max pooling
        pooled = self.max_pool(x)

        return skip, pooled

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(OptimizedEncoderBlock, self).get_config()
        config.update({'filters': self.filters})
        return config


class WeightedDiceLoss(keras.losses.Loss):
    """Class-weighted soft Dice loss for one-hot volumetric segmentation.

    For every class the soft Dice coefficient is computed per sample over
    axes 1-3, averaged over the batch, multiplied by the class weight, and
    the negated mean over classes is returned.

    Note that the weights are not normalised: with the default weights
    ``[1, 2, 2, 3]`` and inputs in ``[0, 1]`` the value lies in ``[-2, 0]``
    (lower is better).

    Args:
        class_weights: One weight per class, in channel order. Its length
            must equal the size of the last axis of ``y_true`` / ``y_pred``.
            Defaults to ``[1.0, 2.0, 2.0, 3.0]`` (4 classes).
        smooth: Small constant added to numerator and denominator so empty
            classes do not divide by zero.
        **kwargs: Forwarded to ``keras.losses.Loss``.
    """

    def __init__(self, class_weights=None, smooth=1e-6, **kwargs):
        super(WeightedDiceLoss, self).__init__(**kwargs)

        if class_weights is None:
            class_weights = [1.0, 2.0, 2.0, 3.0]

        # Plain Python floats are kept alongside the tensor for get_config().
        self._class_weight_values = [float(w) for w in class_weights]
        self.class_weights = tf.constant(self._class_weight_values, dtype=tf.float32)

        self.smooth = smooth

    def call(self, y_true, y_pred):
        """Compute the loss.

        Args:
            y_true: One-hot targets, indexed as (batch, d1, d2, d3, classes).
            y_pred: Predicted class probabilities with the same layout.

        Returns:
            Scalar float32 tensor: the negated mean of the weighted per-class
            Dice scores.
        """
        # Ensure float32 for mixed precision
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)

        # Reduce over the spatial axes only; the result is (batch, classes),
        # so all classes are handled at once instead of a fixed 4-way loop.
        spatial_axes = [1, 2, 3]
        intersection = tf.reduce_sum(y_true * y_pred, axis=spatial_axes)
        union = (
            tf.reduce_sum(y_true, axis=spatial_axes)
            + tf.reduce_sum(y_pred, axis=spatial_axes)
        )

        dice = (2.0 * intersection + self.smooth) / (union + self.smooth)

        # Average over the batch -> one Dice score per class
        dice_scores = tf.reduce_mean(dice, axis=0)

        # Apply class weights
        weighted_dice = dice_scores * self.class_weights

        # Return negative weighted mean (loss to minimize)
        return -tf.reduce_mean(weighted_dice)

    def get_config(self):
        """Return the constructor arguments so the loss can be re-created."""
        config = super(WeightedDiceLoss, self).get_config()
        config.update({
            'class_weights': list(self._class_weight_values),
            'smooth': self.smooth,
        })
        return config


def _decoder_stage(x, skip, filters):
    """Upsample ``x`` 2x, fuse it with a 2x max-pooled ``skip``, and refine.

    Refinement is a 1x1x1 convolution to ``filters`` channels followed by
    batch normalisation, LeakyReLU, dropout and an SE block.
    """
    x = layers.UpSampling3D(size=(2, 2, 2))(x)
    # The skip tensor is one resolution level above `x`, so it is pooled down
    # to match before concatenation.
    skip_downsampled = layers.MaxPooling3D(pool_size=(2, 2, 2))(skip)
    x = layers.concatenate([x, skip_downsampled], axis=-1)
    x = layers.Conv3D(filters, (1, 1, 1), padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(alpha=0.1)(x)
    x = layers.Dropout(0.1)(x)
    return OptimizedSEBlock3D(filters)(x)


def create_optimized_latup_net(input_shape=(64, 64, 64, 3), num_classes=4):
    """
    Create Performance-Optimized LATUP-Net model

    Architecture (spatial sizes shown for the default 64^3 input):
        encoder    64 -> 32 -> 16 -> 8  (16, 32, 64 channels)
        bottleneck 8 -> 4               (128 channels)
        decoder    4 -> 8 -> 16 -> 32   (64, 32, 16 channels)
        head       32 -> 64 upsampling, then a 1x1x1 softmax convolution

    Args:
        input_shape: Input shape (height, width, depth, channels). The network
            pools by 2 four times, so each spatial size should be divisible
            by 16 for the skip connections and output to line up.
        num_classes: Number of softmax output channels (default 4).

    Returns:
        Keras model mapping ``input_shape`` volumes to per-voxel class
        probabilities (float32).
    """
    inputs = keras.Input(shape=input_shape)

    # Input squashing. A built-in Activation layer is used instead of a
    # Lambda so the saved model does not embed serialized Python bytecode.
    x = layers.Activation('sigmoid')(inputs)

    # Encoder path with reduced complexity
    # Level 1: 64x64x64 -> 32x32x32
    pc_block = LightweightParallelConvBlock(16)  # Reduced filters
    skip1, x = pc_block(x)  # skip1: (64,64,64,16), x: (32,32,32,16)

    # Level 2: 32x32x32 -> 16x16x16
    enc_block2 = OptimizedEncoderBlock(32)  # Reduced filters
    skip2, x = enc_block2(x)  # skip2: (32,32,32,32), x: (16,16,16,32)

    # Level 3: 16x16x16 -> 8x8x8
    enc_block3 = OptimizedEncoderBlock(64)  # Reduced filters
    skip3, x = enc_block3(x)  # skip3: (16,16,16,64), x: (8,8,8,64)

    # Bottleneck: 8x8x8 -> 4x4x4
    x = layers.Conv3D(128, (3, 3, 3), padding='same', kernel_regularizer=keras.regularizers.l2(0.01))(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(alpha=0.1)(x)
    x = layers.Dropout(0.1)(x)
    x = OptimizedSEBlock3D(128)(x)

    # Pool to bottleneck
    x = layers.MaxPooling3D(pool_size=(2, 2, 2))(x)  # (4,4,4,128)

    # Decoder path. Every skip tensor is one level larger than the decoder
    # features it joins, so _decoder_stage max-pools it before concatenation.
    # NOTE(review): as a consequence no skip reaches the decoder at its native
    # resolution, and the last 32 -> 64 step is a plain upsampling followed
    # only by the 1x1x1 classifier - confirm this is intended.
    x = _decoder_stage(x, skip3, 64)  # (4,4,4,128)  -> (8,8,8,64)
    x = _decoder_stage(x, skip2, 32)  # (8,8,8,64)   -> (16,16,16,32)
    x = _decoder_stage(x, skip1, 16)  # (16,16,16,32) -> (32,32,32,16)

    # Final upsampling: 32x32x32 -> 64x64x64
    x = layers.UpSampling3D(size=(2, 2, 2))(x)  # (64,64,64,16)

    # Final classification layer
    outputs = layers.Conv3D(
        num_classes,
        kernel_size=1,
        activation='softmax',
        dtype=tf.float32  # Ensure float32 output for mixed precision
    )(x)

    # Create model
    model = keras.Model(inputs=inputs, outputs=outputs, name='Optimized-LATUP-Net')

    return model


# Custom Keras Metrics for dice score monitoring (simplified)
class FastDiceScore(keras.metrics.Metric):
    """Running mean of a global soft Dice score.

    On every update a single Dice value is computed over the whole batch
    (all samples, voxels and classes pooled together, using the raw
    predictions rather than thresholded labels). ``result`` is the mean of
    those per-update values since the last reset.

    Args:
        name: Metric name shown in logs and used by callbacks.
        **kwargs: Forwarded to ``keras.metrics.Metric``.
    """

    def __init__(self, name='dice_score', **kwargs):
        super(FastDiceScore, self).__init__(name=name, **kwargs)
        # Sum of per-update Dice values and the number of updates seen
        self.dice_sum = self.add_weight(name='dice_sum', initializer='zeros', dtype=tf.float32)
        self.count = self.add_weight(name='count', initializer='zeros', dtype=tf.float32)
        self.smooth = 1e-6

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the Dice score of one batch.

        ``sample_weight`` is accepted for API compatibility but not used.
        """
        # Convert to float32
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)

        # Simplified dice calculation over the entire batch
        intersection = tf.reduce_sum(y_true * y_pred)
        union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred)
        dice = (2.0 * intersection + self.smooth) / (union + self.smooth)

        self.dice_sum.assign_add(dice)
        self.count.assign_add(1)

    def result(self):
        """Return the mean Dice over all updates, or 0 if there were none."""
        # divide_no_nan yields 0 instead of NaN when no batch has been seen,
        # which keeps callbacks that compare this metric well-defined.
        total = tf.convert_to_tensor(self.dice_sum)
        updates = tf.convert_to_tensor(self.count)
        return tf.math.divide_no_nan(total, updates)

    def reset_state(self):
        """Clear the accumulated sum and update count."""
        self.dice_sum.assign(0)
        self.count.assign(0)


def resize_volume_fast(volume, target_shape=(64, 64, 64)):
    """Resample a 3-D volume to ``target_shape`` by nearest-neighbour lookup.

    Args:
        volume: Array whose first three axes are the spatial axes.
        target_shape: Desired size along each of those three axes.

    Returns:
        The array returned by ``scipy.ndimage.zoom`` for the computed
        per-axis zoom factors.
    """
    source_shape = volume.shape

    # Per-axis factor = wanted size / present size
    factor_0 = target_shape[0] / source_shape[0]
    factor_1 = target_shape[1] / source_shape[1]
    factor_2 = target_shape[2] / source_shape[2]

    # order=0 is nearest neighbour: fastest, and it never invents new values
    resized = zoom(volume, [factor_0, factor_1, factor_2], order=0)
    return resized


def load_brats_case_optimized(case_path):
    """
    Optimized BraTS case loading with reduced resolution

    Expects ``case_path`` to contain ``<case>_t1ce.nii``, ``<case>_t2.nii``,
    ``<case>_flair.nii`` and ``<case>_seg.nii``, where ``<case>`` is the
    directory name. Failures are reported with ``print`` and signalled by
    returning ``(None, None)`` instead of raising.

    Args:
        case_path: Path to case directory (a trailing separator is tolerated)

    Returns:
        input_volume: (64, 64, 64, 3) float32 - T1ce, T2, FLAIR, each divided
            by its own maximum when that maximum is positive
        mask: (64, 64, 64) int32 - labels remapped to 0..3
            (source 1 -> 1, 2 -> 2, 4 -> 3, everything else -> 0)
    """
    # normpath removes a trailing separator; without it basename("dir/") is ""
    # and every file name below would be built from an empty case name.
    case_name = os.path.basename(os.path.normpath(case_path))

    # Expected file locations
    t1ce_path = os.path.join(case_path, f"{case_name}_t1ce.nii")
    t2_path = os.path.join(case_path, f"{case_name}_t2.nii")
    flair_path = os.path.join(case_path, f"{case_name}_flair.nii")
    seg_path = os.path.join(case_path, f"{case_name}_seg.nii")

    # Check if all files exist
    if not all(os.path.exists(path) for path in [t1ce_path, t2_path, flair_path, seg_path]):
        print(f"Missing files for case: {case_name}")
        return None, None

    def normalize_fast(volume):
        """Scale to [0, 1] by the maximum; non-positive volumes are left as-is."""
        volume = volume.astype(np.float32)
        if volume.max() > 0:
            volume = volume / volume.max()
        return volume

    try:
        # Load volumes
        t1ce = nib.load(t1ce_path).get_fdata()
        t2 = nib.load(t2_path).get_fdata()
        flair = nib.load(flair_path).get_fdata()
        seg = nib.load(seg_path).get_fdata()

        # Resize to smaller target shape for performance. Nearest-neighbour
        # resampling keeps the segmentation labels discrete.
        t1ce = resize_volume_fast(t1ce, target_shape=(64, 64, 64))
        t2 = resize_volume_fast(t2, target_shape=(64, 64, 64))
        flair = resize_volume_fast(flair, target_shape=(64, 64, 64))
        seg = resize_volume_fast(seg, target_shape=(64, 64, 64))

        # Fast normalization
        t1ce = normalize_fast(t1ce)
        t2 = normalize_fast(t2)
        flair = normalize_fast(flair)

        # Stack modalities on a trailing channel axis
        input_volume = np.stack([t1ce, t2, flair], axis=-1)

        # Convert segmentation labels to a contiguous 0..3 range
        seg_processed = np.zeros_like(seg)
        seg_processed[seg == 1] = 1  # necrotic core
        seg_processed[seg == 2] = 2  # peritumoral edema
        seg_processed[seg == 4] = 3  # enhancing tumor

        return input_volume.astype(np.float32), seg_processed.astype(np.int32)

    except Exception as e:
        # Deliberate best effort: one unreadable case must not stop the
        # whole data pipeline, so report it and let the caller skip it.
        print(f"Error loading case {case_path}: {e}")
        return None, None


def create_optimized_brats_generator(case_paths, batch_size=1, shuffle=True):
    """Optimized data generator for BraTS dataset

    Args:
        case_paths: Sequence of case directory paths.
        batch_size: Maximum number of cases per yielded batch.
        shuffle: If true, the case order is reshuffled (via ``np.random``)
            each time the returned generator function is called.

    Returns:
        A zero-argument generator function. Each item it yields is a tuple
        ``(batch_x, batch_y)`` of NumPy arrays: the stacked input volumes and
        the stacked one-hot (4-class, float32) masks. Cases that fail to load
        are skipped, so a batch may hold fewer than ``batch_size`` samples;
        batches with no loadable case are not yielded at all.
    """
    num_classes = 4
    class_ids = np.arange(num_classes)

    def data_generator():
        indices = np.arange(len(case_paths))
        if shuffle:
            np.random.shuffle(indices)

        for start in range(0, len(indices), batch_size):
            batch_x = []
            batch_y = []

            for idx in indices[start:start + batch_size]:
                x, y = load_brats_case_optimized(case_paths[idx])
                if x is None or y is None:
                    continue
                batch_x.append(x)
                # One-hot encode in NumPy: comparing against the class ids
                # gives the same result as tf.one_hot (labels outside
                # 0..num_classes-1 become all-zero rows) without a
                # TensorFlow round trip per sample.
                y_onehot = (y[..., np.newaxis] == class_ids).astype(np.float32)
                batch_y.append(y_onehot)

            if batch_x:
                yield np.stack(batch_x), np.stack(batch_y)

    return data_generator


def create_optimized_tf_dataset(generator_func, num_cases, batch_size=1):
    """Create optimized TensorFlow dataset from generator

    Args:
        generator_func: Zero-argument callable returning an iterator of
            ``(inputs, one_hot_masks)`` batches of shape
            ``(b, 64, 64, 64, 3)`` and ``(b, 64, 64, 64, 4)``.
        num_cases: Number of cases behind the generator. Currently unused;
            kept so existing callers keep working.
        batch_size: Nominal batch size. Currently unused because the batch
            axis is declared as variable; kept so existing callers keep
            working.

    Returns:
        A ``tf.data.Dataset`` that caches the generated batches in memory and
        prefetches them.
    """
    # The batch axis is left as None: a batch can be shorter than the nominal
    # batch size (last batch, or a case that failed to load), and a fixed
    # size would make from_generator reject it.
    output_signature = (
        tf.TensorSpec(shape=(None, 64, 64, 64, 3), dtype=tf.float32),  # Reduced size
        tf.TensorSpec(shape=(None, 64, 64, 64, 4), dtype=tf.float32)
    )

    dataset = tf.data.Dataset.from_generator(
        generator_func,
        output_signature=output_signature
    )

    # Cache first, prefetch last, so the cached epochs also overlap with
    # training.
    # NOTE(review): once the cache is filled the generator is no longer
    # invoked, so any shuffling done inside it is frozen to the first epoch's
    # order - confirm that is acceptable.
    dataset = dataset.cache()
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    return dataset


def compile_optimized_model(model):
    """Attach optimizer, loss and metrics to ``model`` and return it.

    Configuration:
        optimizer: Adam, learning rate 1e-3, gradient norm clipped to 1.0
        loss:      WeightedDiceLoss with its default class weights
        metrics:   'accuracy' and FastDiceScore (logged as 'dice_score')

    The model is compiled in place; the same object is returned for chaining.
    """
    # Kept deliberately small: only two metrics are tracked during training.
    tracked_metrics = ['accuracy', FastDiceScore(name='dice_score')]

    # Higher learning rate for faster convergence; clipping for stability.
    adam = keras.optimizers.Adam(learning_rate=1e-3, clipnorm=1.0)

    model.compile(optimizer=adam, loss=WeightedDiceLoss(), metrics=tracked_metrics)
    return model


def train_optimized_model(model, train_data, val_data=None, epochs=50):
    """Train model with optimized settings

    Callbacks (all maximising the Dice metric):
        * ModelCheckpoint - best weights to 'best_optimized_latup.weights.h5'
        * ReduceLROnPlateau - halve the learning rate after 5 stagnant
          epochs, down to 1e-6
        * EarlyStopping - stop after 10 stagnant epochs and restore the best
          weights

    Args:
        model: Compiled Keras model that reports a 'dice_score' metric.
        train_data: Training data accepted by ``model.fit``.
        val_data: Optional validation data. When given, the callbacks monitor
            'val_dice_score'; otherwise 'dice_score'.
        epochs: Maximum number of epochs.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    # Explicit None check: truthiness of datasets/arrays is not a reliable
    # "was validation data supplied" test.
    monitor = 'val_dice_score' if val_data is not None else 'dice_score'

    # Optimized callbacks
    callbacks = [
        keras.callbacks.ModelCheckpoint(
            'best_optimized_latup.weights.h5',
            monitor=monitor,
            save_best_only=True,
            save_weights_only=True,
            mode='max',
            verbose=1
        ),
        keras.callbacks.ReduceLROnPlateau(
            monitor=monitor,
            factor=0.5,
            patience=5,  # Reduced patience
            min_lr=1e-6,
            mode='max',
            verbose=1
        ),
        keras.callbacks.EarlyStopping(
            monitor=monitor,
            patience=10,  # Reduced patience
            restore_best_weights=True,
            mode='max',
            verbose=1
        )
    ]

    print("🚀 Starting Optimized LATUP-Net training...")
    print("✅ Performance optimizations applied:")
    print("   • Reduced model complexity")
    print("   • Mixed precision training")
    print("   • Smaller input size (64³)")
    print("   • Batch normalization")
    print("   • Simplified SE blocks")
    print("   • XLA disabled for faster startup")
    print("   • GPU memory growth enabled")

    history = model.fit(
        train_data,
        validation_data=val_data,
        epochs=epochs,
        callbacks=callbacks,
        verbose=1
    )

    return history


# Main optimized training script
#
# Builds and compiles the model, smoke-tests a forward pass, then either
# trains on a 20-case BraTS subset (when DATA_DIR exists) or runs a 2-epoch
# sanity check on random tensors.
if __name__ == "__main__":
    print("🔧 Performance-Optimized LATUP-Net")
    print("=" * 50)

    # Create optimized model
    print("Creating optimized model...")
    model = create_optimized_latup_net(input_shape=(64, 64, 64, 3))
    model = compile_optimized_model(model)

    # Test model creation with a single random volume
    print("Testing model...")
    dummy_input = tf.random.normal((1, 64, 64, 64, 3))
    dummy_output = model(dummy_input)
    print(f"✅ Model test passed! Output shape: {dummy_output.shape}")

    # Print model summary
    model.summary()

    # Data directory
    DATA_DIR = "/kaggle/input/brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData"

    if not os.path.exists(DATA_DIR):
        print(f"Data directory not found: {DATA_DIR}")
        print("Testing with dummy data...")

        # Create dummy data for testing. The targets are uniform noise, not
        # one-hot masks: this only checks that the training loop runs.
        dummy_train_data = tf.data.Dataset.from_tensor_slices((
            tf.random.normal((4, 64, 64, 64, 3)),
            tf.random.uniform((4, 64, 64, 64, 4), maxval=1.0)
        )).batch(1).prefetch(tf.data.AUTOTUNE)

        dummy_val_data = tf.data.Dataset.from_tensor_slices((
            tf.random.normal((2, 64, 64, 64, 3)),
            tf.random.uniform((2, 64, 64, 64, 4), maxval=1.0)
        )).batch(1).prefetch(tf.data.AUTOTUNE)

        # Test training
        print("Testing training with dummy data...")
        history = train_optimized_model(model, dummy_train_data, dummy_val_data, epochs=2)
        print("✅ Optimized training test completed!")

    else:
        # Load real data
        print("Loading BraTS dataset...")
        # glob returns paths in arbitrary order; sorting makes the 20-case
        # subset and the seeded train/val split reproducible across runs.
        case_paths = sorted(glob.glob(os.path.join(DATA_DIR, "BraTS20_Training_*")))[:20]  # Use only 20 cases for testing
        print(f"Using {len(case_paths)} cases for testing")

        # Fail early with a clear message; train_test_split would otherwise
        # raise a less helpful ValueError on 0 or 1 cases.
        if len(case_paths) < 2:
            raise ValueError(
                f"Need at least 2 case directories matching 'BraTS20_Training_*' "
                f"in {DATA_DIR}, found {len(case_paths)}"
            )

        # Split data
        train_paths, val_paths = train_test_split(case_paths, test_size=0.2, random_state=42)

        # Create optimized generators
        train_gen = create_optimized_brats_generator(train_paths, batch_size=1, shuffle=True)
        val_gen = create_optimized_brats_generator(val_paths, batch_size=1, shuffle=False)

        # Create datasets
        train_dataset = create_optimized_tf_dataset(train_gen, len(train_paths), batch_size=1)
        val_dataset = create_optimized_tf_dataset(val_gen, len(val_paths), batch_size=1)

        # Train model
        history = train_optimized_model(model, train_dataset, val_dataset, epochs=100)

        # Save model
        model.save('optimized_latup_net.h5')
        print("✅ Optimized model saved!")
    
   

🔧 Performance-Optimized LATUP-Net
Creating optimized model...
Testing model...
✅ Model test passed! Output shape: (1, 64, 64, 64, 4)


Loading BraTS dataset...
Using 20 cases for testing
🚀 Starting Optimized LATUP-Net training...
✅ Performance optimizations applied:
   • Reduced model complexity
   • Mixed precision training
   • Smaller input size (64³)
   • Batch normalization
   • Simplified SE blocks
   • XLA disabled for faster startup
   • GPU memory growth enabled
Epoch 1/100
     16/Unknown [1m34s[0m 134ms/step - accuracy: 0.2032 - dice_score: 0.2366 - loss: 1.1439
Epoch 1: val_dice_score improved from -inf to 0.23068, saving model to best_optimized_latup.weights.h5
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m38s[0m 424ms/step - accuracy: 0.2092 - dice_score: 0.2376 - loss: 1.1275 - val_accuracy: 0.0048 - val_dice_score: 0.2307 - val_loss: 0.3788 - learning_rate: 0.0010
Epoch 2/100
[1m13/16[0m [32m━━━━━━━━━━━━━━━━[0m[37m━━━━[0m [1m0s[0m 15ms/step - accuracy: 0.5334 - dice_score: 0.2976 - loss: 0.2473
Epoch 2: val_dice_score improved from 0.23068 to 0.24772, saving model to best_optimize