# Component 08b: DenseNet121 + Spatial Attention (Advanced)

## 🔥 Novel Architecture Enhancement

This notebook adds a **Spatial Attention Mechanism** on top of the DenseNet121 feature maps for improved feature focus.

**Attention Benefits**:
- ✅ Learns WHERE to focus in the image
- ✅ Improves interpretability
- ✅ Better performance on subtle features
- ✅ State-of-the-art for medical imaging

**Citation**: Woo et al. (2018), "CBAM: Convolutional Block Attention Module", ECCV 2018. Only the spatial-attention branch of CBAM is used here.

In [None]:
# --- Imports: standard library first, then third-party ---
import json
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.utils.class_weight import compute_class_weight

# --- Reproducibility: seed TensorFlow and NumPy with the same value ---
SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)

# Global Keras dtype policy: float16 compute with float32 variables.
tf.keras.mixed_precision.set_global_policy('mixed_float16')

# --- Output locations (created up front so later cells can write freely) ---
OUTPUT_DIR = '../outputs'
for subdir in ('models', 'training_history'):
    os.makedirs(f'{OUTPUT_DIR}/{subdir}', exist_ok=True)

print('✅ Setup complete with mixed precision')

## Configuration

In [None]:
# Manifests written by an earlier pipeline stage; this notebook reads the
# `filepath` and `class_label` columns from them.
train_df = pd.read_csv(f'{OUTPUT_DIR}/train_manifest.csv')
val_df = pd.read_csv(f'{OUTPUT_DIR}/val_manifest.csv')

# Image / batching / schedule hyper-parameters.
IMG_SIZE, BATCH_SIZE, EPOCHS = (224, 224), 32, 50
INITIAL_LR, WEIGHT_DECAY, WARMUP_EPOCHS = 1e-3, 1e-4, 5
LABEL_SMOOTHING, GRADIENT_CLIP_NORM = 0.1, 1.0
DROPOUT_RATE_1, DROPOUT_RATE_2, DROPOUT_RATE_3 = 0.3, 0.4, 0.5

label_values = np.sort(train_df['class_label'].unique())
NUM_CLASSES = len(label_values)

# Downstream cells feed `class_label` straight into tf.one_hot and into the
# class-weight dictionary, both of which are only correct when labels are the
# integer codes 0..NUM_CLASSES-1. Fail here rather than train on bad targets.
assert pd.api.types.is_integer_dtype(train_df['class_label']), \
    'class_label must be integer-encoded (no strings / NaN)'
assert np.array_equal(label_values, np.arange(NUM_CLASSES)), \
    f'class_label must cover 0..{NUM_CLASSES - 1} without gaps, got {label_values.tolist()}'
assert set(val_df['class_label'].unique()) <= set(label_values.tolist()), \
    'validation manifest contains labels that never occur in training'

print(f'Train: {len(train_df)}, Val: {len(val_df)}, Classes: {NUM_CLASSES}')

## Spatial Attention Module

In [None]:
class SpatialAttention(tf.keras.layers.Layer):
    """Spatial attention gate - learns WHERE to focus in the image.

    The input feature map is reduced over its last (channel) axis with both a
    mean and a max, the two resulting single-channel maps are concatenated and
    passed through a `kernel_size` x `kernel_size` sigmoid convolution, and the
    resulting one-channel map in [0, 1] multiplies the input element-wise.

    Args:
        kernel_size: Side length of the convolution that produces the
            attention map.
        **kwargs: Forwarded to `tf.keras.layers.Layer`.

    Loading a saved model that contains this layer requires
    `custom_objects={'SpatialAttention': SpatialAttention}`.
    """

    def __init__(self, kernel_size=7, **kwargs):
        super().__init__(**kwargs)
        self.kernel_size = kernel_size
        # The conv does not depend on the input shape, so create it eagerly;
        # this lets Keras track it as a sub-layer from construction time.
        self.conv = tf.keras.layers.Conv2D(
            filters=1,
            kernel_size=self.kernel_size,
            padding='same',
            activation='sigmoid',
            kernel_initializer='he_normal',
        )

    def build(self, input_shape):
        # The conv sees the 2-channel [mean, max] map, not the raw features.
        stats_shape = list(input_shape)[:-1] + [2]
        if not self.conv.built:
            self.conv.build(stats_shape)
        super().build(input_shape)

    def call(self, inputs):
        # Generate attention map from channel-wise statistics
        avg_pool = tf.reduce_mean(inputs, axis=-1, keepdims=True)
        max_pool = tf.reduce_max(inputs, axis=-1, keepdims=True)
        concat = tf.concat([avg_pool, max_pool], axis=-1)

        # Compute spatial attention weights
        attention = self.conv(concat)

        # Apply attention to inputs (broadcast over the channel axis)
        return inputs * attention

    def compute_output_shape(self, input_shape):
        # Element-wise gating: output shape equals input shape.
        return input_shape

    def get_config(self):
        # Required so a model containing this layer can be serialized
        # (e.g. by ModelCheckpoint) and rebuilt with the same kernel size.
        config = super().get_config()
        config.update({'kernel_size': self.kernel_size})
        return config


print('✅ Spatial Attention Module defined')
print('   - Learns spatial importance across image regions')
print('   - 7×7 conv for large receptive field')
print('   - Sigmoid activation for attention weights [0,1]')

## Preprocessing & Data

In [None]:
def _load_image(fp, label):
    """Read the JPEG at `fp` and resize it to IMG_SIZE; pixels stay in [0, 255]."""
    img = tf.io.read_file(fp)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, IMG_SIZE)
    return img, label


def _normalize(img, label):
    """Apply DenseNet input normalization to a [0, 255] image."""
    # The augmentation layers may emit float16 under the mixed-precision
    # policy; normalize in float32.
    img = tf.cast(img, tf.float32)
    img = tf.keras.applications.densenet.preprocess_input(img)
    return img, label


def preprocess(fp, label):
    """Load, resize and DenseNet-normalize one image.

    Args:
        fp: Scalar string tensor with the path of a JPEG file.
        label: Label passed through unchanged.

    Returns:
        `(image, label)` where `image` is a float32 tensor of spatial size
        IMG_SIZE, ready to feed to DenseNet121.
    """
    img, label = _load_image(fp, label)
    return _normalize(img, label)


# Augmentation pipeline. RandomContrast / RandomBrightness assume (and clip
# to) the [0, 255] pixel range by default, so this must run on raw pixels,
# BEFORE DenseNet normalization.
aug = tf.keras.Sequential([
    tf.keras.layers.RandomFlip('horizontal'),
    tf.keras.layers.RandomRotation(0.15),
    tf.keras.layers.RandomZoom(0.2),
    tf.keras.layers.RandomContrast(0.2),
    tf.keras.layers.RandomBrightness(0.2)
])


def build_dataset(df, augment=False, shuffle=True):
    """Build a batched, prefetched `tf.data.Dataset` from a manifest frame.

    Args:
        df: DataFrame with `filepath` and `class_label` columns.
        augment: If True, apply `aug` to every image on every pass.
        shuffle: If True, shuffle with a 1000-element buffer seeded by SEED.

    Returns:
        Dataset yielding `(images, labels)` batches of size BATCH_SIZE, with
        images DenseNet-normalized.
    """
    ds = tf.data.Dataset.from_tensor_slices((df['filepath'].values, df['class_label'].values))
    # Cache the decoded + resized raw pixels; everything random or cheap
    # happens after the cache so each epoch sees fresh augmentations.
    ds = ds.map(_load_image, tf.data.AUTOTUNE).cache()
    if augment:
        ds = ds.map(lambda x, y: (aug(x, training=True), y), tf.data.AUTOTUNE)
    # Normalize last so augmentation operated on [0, 255] data.
    ds = ds.map(_normalize, tf.data.AUTOTUNE)
    if shuffle:
        ds = ds.shuffle(1000, seed=SEED)
    return ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)


train_ds = build_dataset(train_df, augment=True)
val_ds = build_dataset(val_df, augment=False, shuffle=False)
print('✅ Datasets ready')

## Build DenseNet121 + Attention

In [None]:
# Load the pretrained backbone. With include_top=False and pooling=None the
# output is already a spatial feature map, which is what the attention layer
# needs - a single backbone instance is enough.
base_model = tf.keras.applications.DenseNet121(
    include_top=False,
    pooling=None,  # No pooling - keep spatial dimensions
    weights='imagenet',
    input_shape=(*IMG_SIZE, 3)
)

# Fine-tune last 30 layers; everything before them stays frozen.
for layer in base_model.layers[:-30]:
    layer.trainable = False

# Alias kept so any later cell that refers to `base_no_pool` still works.
base_no_pool = base_model

# Build model with spatial attention
inputs = tf.keras.Input(shape=(*IMG_SIZE, 3))

# Do not hard-code `training=True` here: that would keep the unfrozen
# BatchNormalization layers on per-batch statistics during validation and
# prediction. Leaving it unset lets fit/evaluate/predict choose the mode.
features = base_model(inputs)

# Apply Spatial Attention
attention_features = SpatialAttention(kernel_size=7)(features)

# Global pooling after attention
x = tf.keras.layers.GlobalAveragePooling2D()(attention_features)

# Classification head with dropout
x = tf.keras.layers.Dropout(DROPOUT_RATE_1)(x)
x = tf.keras.layers.Dense(256, activation='relu')(x)
x = tf.keras.layers.Dropout(DROPOUT_RATE_2)(x)
x = tf.keras.layers.Dense(128, activation='relu')(x)
x = tf.keras.layers.Dropout(DROPOUT_RATE_3)(x)
# Final softmax is forced to float32 so probabilities are not computed in
# float16 under the mixed-precision policy.
outputs = tf.keras.layers.Dense(NUM_CLASSES, activation='softmax', dtype='float32')(x)

model = tf.keras.Model(inputs, outputs)

print(f'✅ DenseNet121 + Spatial Attention created')
print(f'   Total params: {model.count_params():,}')
model.summary()

## Label Smoothing Loss

In [None]:
class LabelSmoothingLoss(tf.keras.losses.Loss):
    """Categorical cross-entropy on label-smoothed targets built from sparse labels.

    Each integer label is one-hot encoded and smoothed to
    `one_hot * (1 - smoothing) + smoothing / num_classes` before the
    cross-entropy against `y_pred` is taken.

    Args:
        num_classes: Number of classes (depth of the one-hot encoding).
        smoothing: Smoothing factor.
        **kwargs: Forwarded to `tf.keras.losses.Loss` (e.g. `name`, `reduction`).

    Loading a saved model compiled with this loss requires
    `custom_objects={'LabelSmoothingLoss': LabelSmoothingLoss}`.
    """

    def __init__(self, num_classes, smoothing=0.1, **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.smoothing = smoothing

    def call(self, y_true, y_pred):
        """Return the per-sample loss.

        Args:
            y_true: Integer class ids, shaped like `y_pred` without its last
                axis, optionally with a trailing axis of size 1.
            y_pred: Class probabilities with `num_classes` on the last axis.
        """
        y_true = tf.cast(y_true, tf.int32)
        # Keras may deliver sparse labels as (batch, 1) instead of (batch,).
        # Collapse to y_pred's leading shape so one_hot yields exactly
        # (batch, num_classes) rather than (batch, 1, num_classes).
        y_true = tf.reshape(y_true, tf.shape(y_pred)[:-1])
        y_true_one_hot = tf.one_hot(y_true, self.num_classes)
        y_true_smooth = y_true_one_hot * (1 - self.smoothing) + self.smoothing / self.num_classes
        return tf.keras.losses.categorical_crossentropy(y_true_smooth, y_pred)

    def get_config(self):
        # Needed so the compiled model can be serialized and re-created.
        config = super().get_config()
        config.update({'num_classes': self.num_classes, 'smoothing': self.smoothing})
        return config


loss_fn = LabelSmoothingLoss(NUM_CLASSES, smoothing=LABEL_SMOOTHING)
print(f'✅ Label Smoothing Loss created (ε={LABEL_SMOOTHING})')

## Compile

In [None]:
def get_lr_schedule(epoch, lr):
    """Linear warm-up followed by cosine decay.

    Args:
        epoch: Zero-based epoch index supplied by `LearningRateScheduler`.
        lr: Current learning rate (unused; the schedule is a pure function of
            `epoch`).

    Returns:
        The learning rate for `epoch`: a linear ramp up to INITIAL_LR over the
        first WARMUP_EPOCHS epochs, then half-cosine decay towards 0 at EPOCHS.
    """
    if epoch < WARMUP_EPOCHS:
        return INITIAL_LR * (epoch + 1) / WARMUP_EPOCHS
    # max(1, ...) avoids a zero division when EPOCHS == WARMUP_EPOCHS; the
    # clamp keeps the LR at its floor instead of climbing back up the cosine
    # if training is ever run for more than EPOCHS epochs.
    decay_epochs = max(1, EPOCHS - WARMUP_EPOCHS)
    progress = min(1.0, (epoch - WARMUP_EPOCHS) / decay_epochs)
    return INITIAL_LR * 0.5 * (1 + math.cos(math.pi * progress))


# 'balanced' weights, keyed by the actual class label rather than by position
# in the array, so the mapping is correct by construction.
class_labels = np.unique(train_df['class_label'])
class_weights = compute_class_weight('balanced', classes=class_labels, y=train_df['class_label'])
class_weight_dict = {int(label): float(weight) for label, weight in zip(class_labels, class_weights)}

optimizer = tf.keras.optimizers.AdamW(
    learning_rate=INITIAL_LR,
    weight_decay=WEIGHT_DECAY,
    clipnorm=GRADIENT_CLIP_NORM,
)
model.compile(optimizer=optimizer, loss=loss_fn, metrics=['accuracy'])
print('✅ Compiled with AdamW + Attention-enhanced architecture')

## Train

In [None]:
# Keep only the weights/model with the best validation accuracy.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    f'{OUTPUT_DIR}/models/densenet121_attention_best.h5',
    monitor='val_accuracy',
    save_best_only=True,
    mode='max',
    verbose=1,
)
# Per-epoch learning rate from get_lr_schedule.
lr_cb = tf.keras.callbacks.LearningRateScheduler(get_lr_schedule, verbose=1)
# Stream per-epoch metrics to disk while training runs.
csv_cb = tf.keras.callbacks.CSVLogger(
    f'{OUTPUT_DIR}/training_history/densenet121_attention_training.csv'
)
callbacks = [checkpoint_cb, lr_cb, csv_cb]

banner = '=' * 80
print('\n' + banner)
print('🚀 TRAINING DenseNet121 + SPATIAL ATTENTION')
print(banner)
print('Novel Feature: Spatial attention learns to focus on important brain regions')
print(banner + '\n')

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    callbacks=callbacks,
    class_weight=class_weight_dict,
    verbose=1,
)

print('\n' + banner)
print('✅ TRAINING COMPLETE')
print(banner)

## Save Results

In [None]:
# Save history (JSON for programmatic use, CSV for quick inspection).
# Values are cast to plain floats because the history can hold NumPy scalars,
# which json cannot serialize.
with open(f'{OUTPUT_DIR}/training_history/densenet121_attention_history.json', 'w') as f:
    json.dump({k: [float(v) for v in vals] for k, vals in history.history.items()}, f, indent=2)

pd.DataFrame(history.history).to_csv(f'{OUTPUT_DIR}/training_history/densenet121_attention_history.csv', index=False)

# Plot training curves: one panel per metric, train vs. validation.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
for ax, metric, label in ((ax1, 'accuracy', 'Accuracy'), (ax2, 'loss', 'Loss')):
    ax.plot(history.history[metric], label='Train', linewidth=2)
    ax.plot(history.history[f'val_{metric}'], label='Validation', linewidth=2)
    ax.set_title(f'{label} (with Spatial Attention)', fontsize=14, fontweight='bold')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(label)
    ax.legend()
    ax.grid(alpha=0.3)

plt.suptitle(f'DenseNet121 + Spatial Attention Training', fontsize=16, fontweight='bold')
plt.tight_layout()
# Save before show(): some backends clear the figure once it is displayed.
plt.savefig(f'{OUTPUT_DIR}/training_history/densenet121_attention_curves.png', dpi=200)
plt.show()

print(f'\n📊 Results:')
print(f'   Best Val Accuracy: {max(history.history["val_accuracy"]):.4f}')
print(f'   Final Train Acc: {history.history["accuracy"][-1]:.4f}')
print(f'   Final Val Acc: {history.history["val_accuracy"][-1]:.4f}')
print(f'\n💾 Saved to: {OUTPUT_DIR}/models/densenet121_attention_best.h5')
print(f'\n🔬 Compare with base DenseNet121 to measure attention impact!')