# Component 5: Train Baseline CNN

50 epochs, ModelCheckpoint + ReduceLROnPlateau, **NO EarlyStopping**

In [1]:
import tensorflow as tf
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import json
import os
from sklearn.utils.class_weight import compute_class_weight

SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)

OUTPUT_DIR = '../outputs'
os.makedirs(f'{OUTPUT_DIR}/models', exist_ok=True)
os.makedirs(f'{OUTPUT_DIR}/training_history', exist_ok=True)
print('✓ Setup complete')

✓ Setup complete


## Load Data

In [2]:
train_df = pd.read_csv('../outputs/train_manifest.csv')
val_df = pd.read_csv('../outputs/val_manifest.csv')

IMG_SIZE = (224, 224)
BATCH_SIZE = 32
EPOCHS = 50
NUM_CLASSES = len(train_df['class_label'].unique())

print(f'Train: {{len(train_df)}} images')
print(f'Val:   {{len(val_df)}} images')
print(f'Classes: {{NUM_CLASSES}}')

Train: {len(train_df)} images
Val:   {len(val_df)} images
Classes: {NUM_CLASSES}


## Preprocessing & Augmentation

In [3]:
def preprocess(filepath, label):
    img = tf.io.read_file(filepath)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, IMG_SIZE)
    img = img / 255.0
    return img, label

augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip('horizontal'),
    tf.keras.layers.RandomRotation(0.2),
    tf.keras.layers.RandomZoom(0.2)
])

def build_dataset(df, augment=False, shuffle=True):
    ds = tf.data.Dataset.from_tensor_slices((df['filepath'].values, df['class_label'].values))
    ds = ds.map(preprocess, tf.data.AUTOTUNE).cache()
    if augment:
        ds = ds.map(lambda x, y: (augmentation(x, training=True), y))
    if shuffle:
        ds = ds.shuffle(1000, seed=SEED)
    return ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

train_ds = build_dataset(train_df, augment=True, shuffle=True)
val_ds = build_dataset(val_df, augment=False, shuffle=False)
print('✓ Datasets created')

2026-01-28 23:46:02.058671: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M3
2026-01-28 23:46:02.058714: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2026-01-28 23:46:02.058724: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.92 GB
2026-01-28 23:46:02.058749: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2026-01-28 23:46:02.058765: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


✓ Datasets created


## Build Baseline CNN Model

In [4]:
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, (3,3), activation='relu', input_shape=(*IMG_SIZE, 3)),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Dropout(0.25),
    
    tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Dropout(0.25),
    
    tf.keras.layers.Conv2D(128, (3,3), activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Dropout(0.25),
    
    tf.keras.layers.Conv2D(256, (3,3), activation='relu'),
    tf.keras.layers.GlobalAveragePooling2D(),
    
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
])

model.summary()

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


## Compile & Setup

In [5]:
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Compute class weights
class_weights = compute_class_weight(
    'balanced',
    classes=np.unique(train_df['class_label']),
    y=train_df['class_label']
)
class_weight_dict = {i: w for i, w in enumerate(class_weights)}
print('Class weights:', class_weight_dict)

Class weights: {0: 1.3105867346938775, 1: 3.277511961722488, 2: 5.838068181818182, 3: 0.36224219989423584}


## Callbacks (NO EarlyStopping!)

In [6]:
# ModelCheckpoint - saves best model
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    f'{OUTPUT_DIR}/models/baseline_cnn_best.h5',
    monitor='val_accuracy',
    save_best_only=True,
    mode='max',
    verbose=1
)

# ReduceLROnPlateau - reduces learning rate when stuck
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=3,
    min_lr=1e-7,
    verbose=1
)

callbacks = [checkpoint, reduce_lr]
print('✓ Callbacks configured (NO EarlyStopping)')

✓ Callbacks configured (NO EarlyStopping)


## Train Model

In [7]:
print(f'\nStarting training for {EPOCHS} epochs...')
print('='*60)

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    callbacks=callbacks,
    class_weight=class_weight_dict,
    verbose=1
)

print('\n' + '='*60)
print('✅ TRAINING COMPLETE')
print('='*60)


Starting training for 50 epochs...
Epoch 1/50


2026-01-28 23:46:04.343574: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.


[1m257/257[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 688ms/step - accuracy: 0.3794 - loss: 1.2789
Epoch 1: val_accuracy improved from None to 0.68995, saving model to ../outputs/models/baseline_cnn_best.h5




[1m257/257[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m200s[0m 719ms/step - accuracy: 0.4202 - loss: 1.2273 - val_accuracy: 0.6899 - val_loss: 1.3719 - learning_rate: 1.0000e-04
Epoch 2/50
[1m 11/257[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m2:43[0m 664ms/step - accuracy: 0.4151 - loss: 1.2767

KeyboardInterrupt: 

## Save History & Plots

In [None]:
# Save history as JSON and CSV
history_dict = history.history
history_path = f'{OUTPUT_DIR}/training_history/baseline_cnn_history.json'
csv_path = f'{OUTPUT_DIR}/training_history/baseline_cnn_history.csv'

with open(history_path, 'w') as f:
    json.dump(history_dict, f, indent=2)

import pandas as pd
pd.DataFrame(history_dict).to_csv(csv_path, index=False)

print(f'✓ History saved to {history_path}')
print(f'✓ History saved to {csv_path}')

# Plot training curves
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Accuracy
ax1.plot(history_dict['accuracy'], label='Train', linewidth=2)
ax1.plot(history_dict['val_accuracy'], label='Validation', linewidth=2)
ax1.set_title('Model Accuracy', fontsize=14, fontweight='bold')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Accuracy')
ax1.legend()
ax1.grid(alpha=0.3)

# Loss
ax2.plot(history_dict['loss'], label='Train', linewidth=2)
ax2.plot(history_dict['val_loss'], label='Validation', linewidth=2)
ax2.set_title('Model Loss', fontsize=14, fontweight='bold')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Loss')
ax2.legend()
ax2.grid(alpha=0.3)

plt.suptitle(f'Baseline CNN Training Progress', fontsize=16, fontweight='bold')
plt.tight_layout()

plot_path = f'{OUTPUT_DIR}/training_history/baseline_cnn_curves.png'
plt.savefig(plot_path, dpi=200, bbox_inches='tight')
plt.show()

print(f'✓ Training curves saved to {plot_path}')
print(f'\nBest model saved to: {OUTPUT_DIR}/models/baseline_cnn_best.h5')
print(f'Best val_accuracy: {max(history_dict["val_accuracy"]):.4f}')