# 03 — Train Per‑Cell CNN (MobileNetV2)

Train a 13‑class classifier on patches at `data/final/train|val/<CLASS>/*.jpg`.

- Works on **Kaggle** (`/kaggle/input` + `/kaggle/working`) and **local repo**.
- Saves model to `models/cell_cnn.h5` (or `/kaggle/working/models/cell_cnn.h5`).


In [None]:
# %%capture
# !pip install --quiet tensorflow==2.* opencv-python albumentations tqdm


In [1]:
print(">>> CELL STARTED")

import os, sys, glob
from pathlib import Path
import tensorflow as tf
from tensorflow.keras.utils import image_dataset_from_directory # type: ignore
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping # type: ignore
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input # type: ignore

ON_KAGGLE = Path('/kaggle').exists()
ROOT = Path('/kaggle/working') if ON_KAGGLE else Path('..')

TRAIN_DIR = ROOT / 'data/final/train'
VAL_DIR   = ROOT / 'data/final/val'
MODEL_DIR = ROOT / 'models'
MODEL_DIR.mkdir(parents=True, exist_ok=True)
MODEL_PATH = MODEL_DIR / 'cell_cnn.h5'

IMG_SIZE = (96, 96)
BATCH = 64
EPOCHS = 10

print('Train dir:', TRAIN_DIR)
print('Val   dir:', VAL_DIR)
print('Model   :', MODEL_PATH)


>>> CELL STARTED
Train dir: ..\data\final\train
Val   dir: ..\data\final\val
Model   : ..\models\cell_cnn.h5


In [2]:
# Build datasets (expects class subfolders).
train_ds = image_dataset_from_directory(
    TRAIN_DIR, labels='inferred', label_mode='int',
    image_size=IMG_SIZE, batch_size=BATCH, shuffle=True)
val_ds = image_dataset_from_directory(
    VAL_DIR, labels='inferred', label_mode='int',
    image_size=IMG_SIZE, batch_size=BATCH, shuffle=False)

class_names = train_ds.class_names
num_classes = len(class_names)
print('Classes:', class_names)

# Prefetch + map preprocess_input
AUTOTUNE = tf.data.AUTOTUNE
def prep(x,y):
    return tf.keras.applications.mobilenet_v2.preprocess_input(tf.cast(x, tf.float32)), y
train_ds = train_ds.map(prep).prefetch(AUTOTUNE)
val_ds   = val_ds.map(prep).prefetch(AUTOTUNE)


Found 13632 files belonging to 13 classes.
Found 2069 files belonging to 13 classes.
Classes: ['BB', 'BK', 'BN', 'BP', 'BQ', 'BR', 'Empty', 'WB', 'WK', 'WN', 'WP', 'WQ', 'WR']


In [3]:
import tensorflow as tf

base = tf.keras.applications.MobileNetV2(
    input_shape=IMG_SIZE + (3,), include_top=False, weights='imagenet'
)
base.trainable = False

inp = tf.keras.layers.Input(IMG_SIZE + (3,))
x = base(inp, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dropout(0.2)(x)
out = tf.keras.layers.Dense(num_classes, activation='softmax')(x)

model = tf.keras.Model(inp, out)
model.compile(optimizer=tf.keras.optimizers.Adam(5e-4),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.summary()


In [4]:
ckpt = ModelCheckpoint(str(MODEL_PATH), monitor='val_accuracy', save_best_only=True, verbose=1)
es   = EarlyStopping(monitor='val_accuracy', patience=3, restore_best_weights=True)
hist = model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS, callbacks=[ckpt, es])
print('Best model saved to:', MODEL_PATH)


Epoch 1/10
[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 321ms/step - accuracy: 0.5910 - loss: 1.3398
Epoch 1: val_accuracy improved from None to 0.94248, saving model to ..\models\cell_cnn.h5




[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m88s[0m 392ms/step - accuracy: 0.7805 - loss: 0.7576 - val_accuracy: 0.9425 - val_loss: 0.2544
Epoch 2/10
[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 205ms/step - accuracy: 0.9406 - loss: 0.2619
Epoch 2: val_accuracy improved from 0.94248 to 0.96230, saving model to ..\models\cell_cnn.h5




[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m51s[0m 239ms/step - accuracy: 0.9459 - loss: 0.2343 - val_accuracy: 0.9623 - val_loss: 0.1775
Epoch 3/10
[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 206ms/step - accuracy: 0.9599 - loss: 0.1760
Epoch 3: val_accuracy improved from 0.96230 to 0.97197, saving model to ..\models\cell_cnn.h5




[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m51s[0m 240ms/step - accuracy: 0.9617 - loss: 0.1651 - val_accuracy: 0.9720 - val_loss: 0.1430
Epoch 4/10
[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 200ms/step - accuracy: 0.9683 - loss: 0.1471
Epoch 4: val_accuracy improved from 0.97197 to 0.97438, saving model to ..\models\cell_cnn.h5




[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m50s[0m 233ms/step - accuracy: 0.9704 - loss: 0.1355 - val_accuracy: 0.9744 - val_loss: 0.1243
Epoch 5/10
[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 198ms/step - accuracy: 0.9741 - loss: 0.1127
Epoch 5: val_accuracy improved from 0.97438 to 0.97583, saving model to ..\models\cell_cnn.h5




[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m49s[0m 231ms/step - accuracy: 0.9750 - loss: 0.1110 - val_accuracy: 0.9758 - val_loss: 0.1123
Epoch 6/10
[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 194ms/step - accuracy: 0.9786 - loss: 0.0991
Epoch 6: val_accuracy improved from 0.97583 to 0.97728, saving model to ..\models\cell_cnn.h5




[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m49s[0m 228ms/step - accuracy: 0.9782 - loss: 0.0961 - val_accuracy: 0.9773 - val_loss: 0.1061
Epoch 7/10
[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 198ms/step - accuracy: 0.9800 - loss: 0.0893
Epoch 7: val_accuracy improved from 0.97728 to 0.97922, saving model to ..\models\cell_cnn.h5




[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m50s[0m 232ms/step - accuracy: 0.9809 - loss: 0.0860 - val_accuracy: 0.9792 - val_loss: 0.0972
Epoch 8/10
[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 204ms/step - accuracy: 0.9815 - loss: 0.0772
Epoch 8: val_accuracy improved from 0.97922 to 0.97970, saving model to ..\models\cell_cnn.h5




[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m49s[0m 231ms/step - accuracy: 0.9819 - loss: 0.0767 - val_accuracy: 0.9797 - val_loss: 0.0923
Epoch 9/10
[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 200ms/step - accuracy: 0.9861 - loss: 0.0657
Epoch 9: val_accuracy improved from 0.97970 to 0.98115, saving model to ..\models\cell_cnn.h5




[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m50s[0m 234ms/step - accuracy: 0.9849 - loss: 0.0666 - val_accuracy: 0.9812 - val_loss: 0.0879
Epoch 10/10
[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 181ms/step - accuracy: 0.9849 - loss: 0.0599
Epoch 10: val_accuracy did not improve from 0.98115
[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 211ms/step - accuracy: 0.9857 - loss: 0.0599 - val_accuracy: 0.9812 - val_loss: 0.0844
Best model saved to: ..\models\cell_cnn.h5


### (Optional) Fine‑tune
Unfreeze the base for a few epochs if you want a small boost.


In [None]:
# Optional fine‑tune a few layers
unfreeze = False
if unfreeze:
    base.trainable = True
    model.compile(optimizer=tf.keras.optimizers.Adam(1e-5),
                  loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    hist2 = model.fit(train_ds, validation_data=val_ds, epochs=3, callbacks=[ckpt])


In [5]:
print('✅ Done. Model at:', MODEL_PATH)


✅ Done. Model at: ..\models\cell_cnn.h5


In [6]:
# SAVE CLASS ORDER (robust, handles IMG_SIZE int/tuple)
import json
from pathlib import Path
from tensorflow.keras.utils import image_dataset_from_directory # type: ignore

# --- ensure dirs/vars ---
if 'MODEL_DIR' not in locals():
    MODEL_DIR = Path('models')
MODEL_DIR.mkdir(parents=True, exist_ok=True)
CLASSES_JSON = MODEL_DIR / "classes.json"

# --- normalize image_size to a 2-int tuple ---
def _as_hw_tuple(x):
    # x may be int (96) or tuple like (96,96)
    if isinstance(x, (tuple, list)) and len(x) == 2:
        return (int(x[0]), int(x[1]))
    return (int(x), int(x))

IMG_HW = _as_hw_tuple(IMG_SIZE)

# 1) try to read from train_ds
class_names = None
if 'train_ds' in locals():
    try:
        class_names = list(train_ds.class_names)  # only works before map/prefetch
    except Exception:
        class_names = None

# 2) fallback: build a temporary dataset from folder just to get class_names
if class_names is None:
    if not Path(TRAIN_DIR).exists():
        raise FileNotFoundError(f"TRAIN_DIR not found: {TRAIN_DIR}")
    tmp_ds = image_dataset_from_directory(
        TRAIN_DIR,
        labels='inferred',
        label_mode='int',
        image_size=IMG_HW,      # <— use normalized (h,w)
        batch_size=32,
        shuffle=False
    )
    class_names = list(tmp_ds.class_names)

# 3) save to JSON
CLASSES_JSON.write_text(
    json.dumps(class_names, ensure_ascii=False, indent=2),
    encoding="utf-8"
)
print("✅ Saved class order to:", CLASSES_JSON)
print("   class_names =", class_names)


Found 13632 files belonging to 13 classes.
✅ Saved class order to: ..\models\classes.json
   class_names = ['BB', 'BK', 'BN', 'BP', 'BQ', 'BR', 'Empty', 'WB', 'WK', 'WN', 'WP', 'WQ', 'WR']
