# 03 — Train Per‑Cell CNN (MobileNetV2)

Train a 13‑class classifier on patches at `data/final/train|val/<CLASS>/*.jpg`.

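- Works on **Kaggle** and in a **local repo**. On Kaggle, all paths resolve under `/kaggle/working` (copy the dataset there from `/kaggle/input` first); locally they resolve relative to `..`, the parent of the notebook's working directory.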
- Saves model to `models/cell_cnn.h5` (or `/kaggle/working/models/cell_cnn.h5`).


In [None]:
# %%capture
# !pip install --quiet tensorflow==2.* opencv-python albumentations tqdm
# ^ Optional: uncomment both lines to install dependencies in a fresh environment
#   (`%%capture` must remain the first line of the cell to silence pip's output).
#   NOTE(review): prefer `%pip` over `!pip` so packages go into the running kernel's
#   environment; opencv-python, albumentations and tqdm are not imported in this notebook.


In [1]:
import os, sys, glob
from pathlib import Path
import tensorflow as tf
from tensorflow.keras.utils import image_dataset_from_directory
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input

# Resolve paths for both environments: on Kaggle everything lives under
# /kaggle/working; locally paths are relative to the parent of the current
# working directory (NOTE(review): assumes the notebook runs one level below the repo root).
ON_KAGGLE = Path('/kaggle').exists()
ROOT = Path('/kaggle/working') if ON_KAGGLE else Path('..')

TRAIN_DIR = ROOT / 'data/final/train'
VAL_DIR   = ROOT / 'data/final/val'
MODEL_DIR = ROOT / 'models'
MODEL_PATH = MODEL_DIR / 'cell_cnn.h5'

IMG_SIZE = (96, 96)   # (height, width) patches are resized to
BATCH = 64
EPOCHS = 10
SEED = 42

# Seed Python, NumPy and TensorFlow RNGs so dataset shuffling and the new head's
# initialisation are repeatable across Restart & Run All.
tf.keras.utils.set_random_seed(SEED)

# Fail fast with a clear message rather than an opaque error from the dataset loader.
for split_dir in (TRAIN_DIR, VAL_DIR):
    if not split_dir.is_dir():
        raise FileNotFoundError(
            f'Expected one subfolder per class under {split_dir.resolve()}')

MODEL_DIR.mkdir(parents=True, exist_ok=True)

print('Train dir:', TRAIN_DIR)
print('Val   dir:', VAL_DIR)
print('Model   :', MODEL_PATH)


>>> CELL STARTED
Train dir: ..\data\final\train
Val   dir: ..\data\final\val
Model   : ..\models\cell_cnn.h5


In [2]:
# Build datasets (expects one subfolder per class under each split).
train_ds = image_dataset_from_directory(
    TRAIN_DIR, labels='inferred', label_mode='int',
    image_size=IMG_SIZE, batch_size=BATCH, shuffle=True)

class_names = train_ds.class_names
num_classes = len(class_names)

# Pin the validation label order to the training one. Inferred labels are indices
# into each directory's own sorted subfolder list, so a class folder missing from
# val would otherwise silently shift every val label; with `class_names` given,
# the loader raises a ValueError instead.
val_ds = image_dataset_from_directory(
    VAL_DIR, labels='inferred', label_mode='int',
    class_names=class_names,
    image_size=IMG_SIZE, batch_size=BATCH, shuffle=False)

print('Classes:', class_names)

AUTOTUNE = tf.data.AUTOTUNE

def prep(x, y):
    """Apply MobileNetV2 input preprocessing to an image batch; labels pass through unchanged."""
    return preprocess_input(tf.cast(x, tf.float32)), y

# Parallel map preserves element order by default; prefetch overlaps input with training.
train_ds = train_ds.map(prep, num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)
val_ds   = val_ds.map(prep, num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)


Found 6600 files belonging to 13 classes.
Found 729 files belonging to 13 classes.
Classes: ['BB', 'BK', 'BN', 'BP', 'BQ', 'BR', 'Empty', 'WB', 'WK', 'WN', 'WP', 'WQ', 'WR']


In [3]:
import tensorflow as tf

input_shape = IMG_SIZE + (3,)  # RGB patches

# ImageNet-pretrained MobileNetV2 backbone without its classifier head; frozen so
# that only the new head is trained in the first phase.
base = tf.keras.applications.MobileNetV2(
    weights='imagenet',
    include_top=False,
    input_shape=input_shape,
)
base.trainable = False

layers = tf.keras.layers
inp = layers.Input(input_shape)
# training=False keeps the backbone's BatchNorm layers in inference mode,
# even if `base` is unfrozen later for fine-tuning.
features = base(inp, training=False)
pooled = layers.GlobalAveragePooling2D()(features)
regularized = layers.Dropout(0.2)(pooled)
out = layers.Dense(num_classes, activation='softmax')(regularized)

model = tf.keras.Model(inp, out)
model.compile(
    optimizer=tf.keras.optimizers.Adam(5e-4),
    loss='sparse_categorical_crossentropy',  # integer labels (label_mode='int')
    metrics=['accuracy'],
)
model.summary()


Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_96_no_top.h5
[1m9406464/9406464[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 0us/step


In [4]:
import json

# Keep only the best epoch (by val accuracy) on disk; stop after 3 epochs without
# improvement. restore_best_weights rolls the in-memory model back to its best epoch
# (NOTE(review): older Keras versions only do this when early stopping actually triggers).
ckpt = ModelCheckpoint(str(MODEL_PATH), monitor='val_accuracy', save_best_only=True, verbose=1)
es   = EarlyStopping(monitor='val_accuracy', patience=3, restore_best_weights=True)
hist = model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS, callbacks=[ckpt, es])

# The saved model does not carry the label order, so persist the
# index -> class-name mapping next to it for downstream inference code.
CLASSES_PATH = MODEL_PATH.with_suffix('.classes.json')
CLASSES_PATH.write_text(json.dumps(class_names, indent=2), encoding='utf-8')

print(f"Best val_accuracy: {max(hist.history['val_accuracy']):.4f}")
print('Best model saved to:', MODEL_PATH)
print('Class names saved to:', CLASSES_PATH)


Epoch 1/10
[1m104/104[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 138ms/step - accuracy: 0.4160 - loss: 1.8995
Epoch 1: val_accuracy improved from None to 0.88477, saving model to ..\models\cell_cnn.h5




[1m104/104[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 171ms/step - accuracy: 0.6229 - loss: 1.2336 - val_accuracy: 0.8848 - val_loss: 0.4801
Epoch 2/10
[1m104/104[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 155ms/step - accuracy: 0.8686 - loss: 0.4851
Epoch 2: val_accuracy improved from 0.88477 to 0.92044, saving model to ..\models\cell_cnn.h5




[1m104/104[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 176ms/step - accuracy: 0.8823 - loss: 0.4365 - val_accuracy: 0.9204 - val_loss: 0.3248
Epoch 3/10
[1m103/104[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 159ms/step - accuracy: 0.9206 - loss: 0.3186
Epoch 3: val_accuracy improved from 0.92044 to 0.92867, saving model to ..\models\cell_cnn.h5




[1m104/104[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 178ms/step - accuracy: 0.9230 - loss: 0.3129 - val_accuracy: 0.9287 - val_loss: 0.2758
Epoch 4/10
[1m103/104[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 152ms/step - accuracy: 0.9369 - loss: 0.2587
Epoch 4: val_accuracy improved from 0.92867 to 0.94650, saving model to ..\models\cell_cnn.h5




[1m104/104[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 170ms/step - accuracy: 0.9400 - loss: 0.2540 - val_accuracy: 0.9465 - val_loss: 0.2537
Epoch 5/10
[1m103/104[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 150ms/step - accuracy: 0.9475 - loss: 0.2250
Epoch 5: val_accuracy did not improve from 0.94650
[1m104/104[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 166ms/step - accuracy: 0.9471 - loss: 0.2254 - val_accuracy: 0.9424 - val_loss: 0.2427
Epoch 6/10
[1m103/104[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 150ms/step - accuracy: 0.9531 - loss: 0.2031
Epoch 6: val_accuracy improved from 0.94650 to 0.94925, saving model to ..\models\cell_cnn.h5




[1m104/104[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 168ms/step - accuracy: 0.9527 - loss: 0.1991 - val_accuracy: 0.9492 - val_loss: 0.2329
Epoch 7/10
[1m103/104[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 155ms/step - accuracy: 0.9610 - loss: 0.1763
Epoch 7: val_accuracy improved from 0.94925 to 0.95336, saving model to ..\models\cell_cnn.h5




[1m104/104[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 173ms/step - accuracy: 0.9591 - loss: 0.1836 - val_accuracy: 0.9534 - val_loss: 0.2249
Epoch 8/10
[1m103/104[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 152ms/step - accuracy: 0.9606 - loss: 0.1692
Epoch 8: val_accuracy did not improve from 0.95336
[1m104/104[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 167ms/step - accuracy: 0.9612 - loss: 0.1675 - val_accuracy: 0.9534 - val_loss: 0.2177
Epoch 9/10
[1m104/104[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 187ms/step - accuracy: 0.9651 - loss: 0.1482
Epoch 9: val_accuracy did not improve from 0.95336
[1m104/104[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 205ms/step - accuracy: 0.9647 - loss: 0.1527 - val_accuracy: 0.9520 - val_loss: 0.2129
Epoch 10/10
[1m104/104[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 192ms/step - accuracy: 0.9665 - loss: 0.1452
Epoch 10: val_accuracy did not improve from 0.95336
[1m104/104[0m 

### (Optional) Fine‑tune
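Set `unfreeze = True` in the next cell to fine-tune the MobileNetV2 backbone for a few more epochs at a much lower learning rate; this can give a small accuracy boost.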


In [None]:
# Optional fine-tune: unfreeze only the top FINE_TUNE_LAYERS layers of the backbone.
unfreeze = False
FINE_TUNE_LAYERS = 30   # backbone layers, counted from the top, made trainable
FINE_TUNE_EPOCHS = 3
FINE_TUNE_LR = 1e-5     # much lower than the head-training LR so pretrained features survive
if unfreeze:
    base.trainable = True
    # Re-freeze everything below the cut. Computing the cut explicitly avoids the
    # `layers[:-0]` pitfall (an empty slice that would leave the whole base trainable).
    cut = max(len(base.layers) - FINE_TUNE_LAYERS, 0)
    for layer in base.layers[:cut]:
        layer.trainable = False
    # Changes to `trainable` only take effect after recompiling.
    model.compile(optimizer=tf.keras.optimizers.Adam(FINE_TUNE_LR),
                  loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    # `ckpt` is reused, so its best score from the first run should carry over and the
    # file is only overwritten if fine-tuning beats it (NOTE(review): confirm for your Keras version).
    hist2 = model.fit(train_ds, validation_data=val_ds, epochs=FINE_TUNE_EPOCHS,
                      callbacks=[ckpt, es])


In [5]:
# Final confirmation of where the best checkpoint lives.
print(f'✅ Done. Model at: {MODEL_PATH}')


✅ Done. Model at: ..\models\cell_cnn.h5
