In [2]:
import tensorflow as tf
from sklearn.utils import class_weight
import numpy as np
import matplotlib.pyplot as plt
from tf_data_pipeline import gather_filepaths, split_by_patient, make_dataset


# --- CONFIG ---
ROOT = "BreaKHis_v1/histology_slides/breast"
MAG = "40X"
EPOCHS = 10
BATCH = 32
IMG_SIZE = (224, 224)
SEED = 42

# --- Load Dataset ---
items = gather_filepaths(ROOT, mag=MAG)
train_pairs, val_pairs, test_pairs = split_by_patient(items)

train_ds = make_dataset(train_pairs, augment_data=True)
val_ds = make_dataset(val_pairs, augment_data=False)
test_ds = make_dataset(test_pairs, augment_data=False)

print(f"Train: {len(train_pairs)} | Validation: {len(val_pairs)} | Test: {len(test_pairs)}")


# --- Compute Class Weights ---
y_train = [label for _, label in train_pairs]
class_weights = class_weight.compute_class_weight(
    class_weight='balanced',
    classes=np.unique(y_train),
    y=y_train
)
class_weights = dict(enumerate(class_weights))
print("\nClass Weights:", class_weights)

Train: 1512 | Validation: 270 | Test: 213

Class Weights: {0: np.float64(1.2989690721649485), 1: np.float64(0.8129032258064516)}


In [3]:
# --- Base Model (Frozen) ---
base_model = tf.keras.applications.efficientnet_v2.EfficientNetV2B0(
    include_top=False,
    weights='imagenet',
    input_shape=(*IMG_SIZE, 3)
)

base_model.trainable = False  # Freeze base initially

# --- Build Transfer Learning Model ---
inputs = tf.keras.Input(shape=(*IMG_SIZE, 3))
x = tf.keras.applications.efficientnet.preprocess_input(inputs)
x = base_model(x, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dropout(0.4)(x)
outputs = tf.keras.layers.Dense(1, activation='sigmoid')(x)

model = tf.keras.Model(inputs, outputs)
model.summary()

# --- Compile ---
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss='binary_crossentropy',
    metrics=['accuracy']
)

In [4]:
# --- Train Phase 1 (frozen base) ---
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    # class_weight=class_weights
)

Epoch 1/10


[1m48/48[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m80s[0m 1s/step - accuracy: 0.5952 - loss: 0.6860 - val_accuracy: 0.8926 - val_loss: 0.5772
Epoch 2/10
[1m48/48[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m66s[0m 1s/step - accuracy: 0.6012 - loss: 0.6735 - val_accuracy: 0.8926 - val_loss: 0.5235
Epoch 3/10
[1m48/48[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m67s[0m 1s/step - accuracy: 0.5972 - loss: 0.6791 - val_accuracy: 0.8926 - val_loss: 0.5007
Epoch 4/10
[1m48/48[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m67s[0m 1s/step - accuracy: 0.5979 - loss: 0.6856 - val_accuracy: 0.8926 - val_loss: 0.5833
Epoch 5/10
[1m48/48[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m69s[0m 1s/step - accuracy: 0.6111 - loss: 0.6779 - val_accuracy: 0.8926 - val_loss: 0.5269
Epoch 6/10
[1m48/48[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m80s[0m 1s/step - accuracy: 0.5913 - loss: 0.6795 - val_accuracy: 0.8926 - val_loss: 0.4864
Epoch 7/10
[1m48/48[0m [32m━━━━━━━━━━━━━━━━━━━━

In [5]:
# --- Fine-tuning Phase ---
base_model.trainable = True
fine_tune_at = len(base_model.layers) // 2  # unfreeze last half
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    loss='binary_crossentropy',
    metrics=['accuracy']
)

fine_tune_epochs = 5
total_epochs = EPOCHS + fine_tune_epochs

history_fine = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=total_epochs,
    initial_epoch=history.epoch[-1],
    # class_weight=class_weights
)

Epoch 10/15
[1m48/48[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m200s[0m 3s/step - accuracy: 0.4974 - loss: 0.7107 - val_accuracy: 0.1481 - val_loss: 0.7408
Epoch 11/15
[1m48/48[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m135s[0m 3s/step - accuracy: 0.5476 - loss: 0.6893 - val_accuracy: 0.1556 - val_loss: 0.7322
Epoch 12/15
[1m48/48[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m140s[0m 3s/step - accuracy: 0.5761 - loss: 0.6792 - val_accuracy: 0.2222 - val_loss: 0.7114
Epoch 13/15
[1m48/48[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m135s[0m 3s/step - accuracy: 0.5886 - loss: 0.6675 - val_accuracy: 0.8926 - val_loss: 0.6434
Epoch 14/15
[1m48/48[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m131s[0m 3s/step - accuracy: 0.5860 - loss: 0.6680 - val_accuracy: 0.3852 - val_loss: 0.6998
Epoch 15/15
[1m48/48[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m131s[0m 3s/step - accuracy: 0.5866 - loss: 0.6703 - val_accuracy: 0.7778 - val_loss: 0.6611
