# 03 — Train Per‑Cell CNN (MobileNetV2)

Train a 13‑class classifier on patches at `data/final/train|val/<CLASS>/*.jpg`.

- Works on **Kaggle** (`/kaggle/input` + `/kaggle/working`) and **local repo**.
- Saves model to `models/cell_cnn.h5` (or `/kaggle/working/models/cell_cnn.h5`).


In [1]:
# %%capture
# !pip install --quiet tensorflow==2.* opencv-python albumentations tqdm


In [2]:
print(">>> CELL STARTED")

import os, sys, glob
from pathlib import Path
import tensorflow as tf
from tensorflow.keras.utils import image_dataset_from_directory # type: ignore
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping # type: ignore
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input # type: ignore

ON_KAGGLE = Path('/kaggle').exists()
ROOT = Path('/kaggle/working') if ON_KAGGLE else Path('..')

TRAIN_DIR = ROOT / 'data/final/train'
VAL_DIR   = ROOT / 'data/final/val'
MODEL_DIR = ROOT / 'models'
MODEL_DIR.mkdir(parents=True, exist_ok=True)
MODEL_PATH = MODEL_DIR / 'cell_cnn.h5'

IMG_SIZE = (96, 96)
BATCH = 64
EPOCHS = 10

print('Train dir:', TRAIN_DIR)
print('Val   dir:', VAL_DIR)
print('Model   :', MODEL_PATH)


>>> CELL STARTED
Train dir: ..\data\final\train
Val   dir: ..\data\final\val
Model   : ..\models\cell_cnn.h5


In [3]:
# Build datasets (expects class subfolders).
train_ds = image_dataset_from_directory(
    TRAIN_DIR, labels='inferred', label_mode='int',
    image_size=IMG_SIZE, batch_size=BATCH, shuffle=True)
val_ds = image_dataset_from_directory(
    VAL_DIR, labels='inferred', label_mode='int',
    image_size=IMG_SIZE, batch_size=BATCH, shuffle=False)

class_names = train_ds.class_names
num_classes = len(class_names)
print('Classes:', class_names)

# Prefetch + map preprocess_input
AUTOTUNE = tf.data.AUTOTUNE
def prep(x,y):
    return tf.keras.applications.mobilenet_v2.preprocess_input(tf.cast(x, tf.float32)), y
train_ds = train_ds.map(prep).prefetch(AUTOTUNE)
val_ds   = val_ds.map(prep).prefetch(AUTOTUNE)


Found 13632 files belonging to 13 classes.
Found 2069 files belonging to 13 classes.
Classes: ['BB', 'BK', 'BN', 'BP', 'BQ', 'BR', 'Empty', 'WB', 'WK', 'WN', 'WP', 'WQ', 'WR']


In [4]:
import tensorflow as tf

base = tf.keras.applications.MobileNetV2(
    input_shape=IMG_SIZE + (3,), include_top=False, weights='imagenet'
)
base.trainable = False

inp = tf.keras.layers.Input(IMG_SIZE + (3,))
x = base(inp, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dropout(0.2)(x)
out = tf.keras.layers.Dense(num_classes, activation='softmax')(x)

model = tf.keras.Model(inp, out)
model.compile(optimizer=tf.keras.optimizers.Adam(5e-4),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.summary()


In [5]:
ckpt = ModelCheckpoint(str(MODEL_PATH), monitor='val_accuracy', save_best_only=True, verbose=1)
es   = EarlyStopping(monitor='val_accuracy', patience=3, restore_best_weights=True)
hist = model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS, callbacks=[ckpt, es])
print('Best model saved to:', MODEL_PATH)


Epoch 1/10


[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 209ms/step - accuracy: 0.5884 - loss: 1.3830
Epoch 1: val_accuracy improved from None to 0.94393, saving model to ..\models\cell_cnn.h5




[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m56s[0m 246ms/step - accuracy: 0.7831 - loss: 0.7617 - val_accuracy: 0.9439 - val_loss: 0.2605
Epoch 2/10
[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 145ms/step - accuracy: 0.9350 - loss: 0.2628
Epoch 2: val_accuracy improved from 0.94393 to 0.96085, saving model to ..\models\cell_cnn.h5




[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m36s[0m 168ms/step - accuracy: 0.9427 - loss: 0.2358 - val_accuracy: 0.9609 - val_loss: 0.1768
Epoch 3/10
[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 149ms/step - accuracy: 0.9601 - loss: 0.1724
Epoch 3: val_accuracy improved from 0.96085 to 0.97003, saving model to ..\models\cell_cnn.h5




[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m37s[0m 173ms/step - accuracy: 0.9616 - loss: 0.1662 - val_accuracy: 0.9700 - val_loss: 0.1464
Epoch 4/10
[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 148ms/step - accuracy: 0.9678 - loss: 0.1346
Epoch 4: val_accuracy improved from 0.97003 to 0.97438, saving model to ..\models\cell_cnn.h5




[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m36s[0m 170ms/step - accuracy: 0.9687 - loss: 0.1316 - val_accuracy: 0.9744 - val_loss: 0.1305
Epoch 5/10
[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 143ms/step - accuracy: 0.9749 - loss: 0.1157
Epoch 5: val_accuracy improved from 0.97438 to 0.97487, saving model to ..\models\cell_cnn.h5




[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m35s[0m 165ms/step - accuracy: 0.9748 - loss: 0.1137 - val_accuracy: 0.9749 - val_loss: 0.1164
Epoch 6/10
[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 140ms/step - accuracy: 0.9802 - loss: 0.0943
Epoch 6: val_accuracy improved from 0.97487 to 0.97632, saving model to ..\models\cell_cnn.h5




[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m35s[0m 162ms/step - accuracy: 0.9790 - loss: 0.0958 - val_accuracy: 0.9763 - val_loss: 0.1074
Epoch 7/10
[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 137ms/step - accuracy: 0.9813 - loss: 0.0844
Epoch 7: val_accuracy improved from 0.97632 to 0.97922, saving model to ..\models\cell_cnn.h5




[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m34s[0m 159ms/step - accuracy: 0.9809 - loss: 0.0848 - val_accuracy: 0.9792 - val_loss: 0.1016
Epoch 8/10
[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 139ms/step - accuracy: 0.9823 - loss: 0.0753
Epoch 8: val_accuracy did not improve from 0.97922
[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m34s[0m 160ms/step - accuracy: 0.9824 - loss: 0.0748 - val_accuracy: 0.9763 - val_loss: 0.0956
Epoch 9/10
[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 131ms/step - accuracy: 0.9844 - loss: 0.0684
Epoch 9: val_accuracy did not improve from 0.97922
[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m39s[0m 151ms/step - accuracy: 0.9844 - loss: 0.0677 - val_accuracy: 0.9778 - val_loss: 0.0909
Epoch 10/10
[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 137ms/step - accuracy: 0.9876 - loss: 0.0575
Epoch 10: val_accuracy improved from 0.97922 to 0.98163, saving model



[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m34s[0m 159ms/step - accuracy: 0.9864 - loss: 0.0598 - val_accuracy: 0.9816 - val_loss: 0.0874
Best model saved to: ..\models\cell_cnn.h5


### (Optional) Fine‑tune
Unfreeze the base for a few epochs if you want a small boost.


In [None]:
# Optional fine‑tune a few layers
unfreeze = False
if unfreeze:
    base.trainable = True
    model.compile(optimizer=tf.keras.optimizers.Adam(1e-5),
                  loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    hist2 = model.fit(train_ds, validation_data=val_ds, epochs=3, callbacks=[ckpt])


In [6]:
print('✅ Done. Model at:', MODEL_PATH)


✅ Done. Model at: ..\models\cell_cnn.h5


In [7]:
# SAVE CLASS ORDER (robust, handles IMG_SIZE int/tuple)
import json
from pathlib import Path
from tensorflow.keras.utils import image_dataset_from_directory # type: ignore

# --- ensure dirs/vars ---
if 'MODEL_DIR' not in locals():
    MODEL_DIR = Path('models')
MODEL_DIR.mkdir(parents=True, exist_ok=True)
CLASSES_JSON = MODEL_DIR / "classes.json"

# --- normalize image_size to a 2-int tuple ---
def _as_hw_tuple(x):
    # x may be int (96) or tuple like (96,96)
    if isinstance(x, (tuple, list)) and len(x) == 2:
        return (int(x[0]), int(x[1]))
    return (int(x), int(x))

IMG_HW = _as_hw_tuple(IMG_SIZE)

# 1) try to read from train_ds
class_names = None
if 'train_ds' in locals():
    try:
        class_names = list(train_ds.class_names)  # only works before map/prefetch
    except Exception:
        class_names = None

# 2) fallback: build a temporary dataset from folder just to get class_names
if class_names is None:
    if not Path(TRAIN_DIR).exists():
        raise FileNotFoundError(f"TRAIN_DIR not found: {TRAIN_DIR}")
    tmp_ds = image_dataset_from_directory(
        TRAIN_DIR,
        labels='inferred',
        label_mode='int',
        image_size=IMG_HW,      # <— use normalized (h,w)
        batch_size=32,
        shuffle=False
    )
    class_names = list(tmp_ds.class_names)

# 3) save to JSON
CLASSES_JSON.write_text(
    json.dumps(class_names, ensure_ascii=False, indent=2),
    encoding="utf-8"
)
print("✅ Saved class order to:", CLASSES_JSON)
print("   class_names =", class_names)


Found 13632 files belonging to 13 classes.
✅ Saved class order to: ..\models\classes.json
   class_names = ['BB', 'BK', 'BN', 'BP', 'BQ', 'BR', 'Empty', 'WB', 'WK', 'WN', 'WP', 'WQ', 'WR']


In [8]:
# --- Single/Multi-image sanity check (supports Empty, BP/BN/BB/WP/...) ---
from pathlib import Path
import json, random, cv2, numpy as np
import tensorflow as tf
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input

# ===== paths =====
ROOT         = Path("..").resolve()
MODEL_PATH   = ROOT / "models/cell_cnn.h5"
CLASSES_JSON = ROOT / "models/classes.json"

# ===== load model & class order =====
model   = tf.keras.models.load_model(str(MODEL_PATH))
CLASSES = json.loads(CLASSES_JSON.read_text(encoding="utf-8"))
IMG_SIZE = 96  # ต้องตรงกับตอนเทรน/อินเฟอร์

print("classes:", CLASSES)

# ===== util: find candidate images for a given class =====
SEARCH_DIRS = [
    ROOT / "data/final/val",
    ROOT / "data/final/train",
    ROOT / "data/public/cells",
    ROOT / "data/bootstrap/cells",
]
def find_images_for_class(class_name: str):
    exts = ("*.jpg", "*.png", "*.jpeg")
    files = []
    for base in SEARCH_DIRS:
        p = base / class_name
        if p.exists():
            for ext in exts:
                files += sorted(p.glob(ext))
    return files

# ===== core predict =====
def _prep_tensor(bgr, size=IMG_SIZE):
    rgb = cv2.cvtColor(cv2.resize(bgr, (size, size)), cv2.COLOR_BGR2RGB).astype(np.float32)
    x   = preprocess_input(rgb)           # MobileNetV2 preprocess
    return np.expand_dims(x, axis=0)      # (1,H,W,3)

def predict_image(path: Path, topk=5):
    bgr = cv2.imread(str(path), cv2.IMREAD_COLOR)
    assert bgr is not None, f"cannot read: {path}"
    x = _prep_tensor(bgr, IMG_SIZE)
    probs = model.predict(x, verbose=0)[0]     # (C,)
    order = probs.argsort()[::-1]
    topk = min(topk, len(order))
    return [(CLASSES[i], float(probs[i])) for i in order[:topk]]

def predict_one_sample(class_name: str, idx: int | None = None, topk=5):
    cands = find_images_for_class(class_name)
    assert cands, f"ไม่พบรูปของคลาส '{class_name}' ในโฟลเดอร์ที่กำหนด"
    if idx is None:
        idx = 0
    idx = max(0, min(idx, len(cands)-1))
    img_path = cands[idx]
    res = predict_image(img_path, topk=topk)
    print(f"\nSample: {img_path.name}  (class='{class_name}', idx={idx}, total={len(cands)})")
    for k,(name,score) in enumerate(res,1):
        print(f"  {k:>2d}. {name:>5s}: {score:.3f}")
    print(f"Pred -> {res[0][0]}  conf={res[0][1]:.3f}")
    return res

def predict_many(class_names: list[str], k_per_class=3, shuffle=True, topk=5):
    for cn in class_names:
        cands = find_images_for_class(cn)
        if not cands:
            print(f"[skip] no images for '{cn}'"); 
            continue
        picks = cands if k_per_class is None else cands[:]
        if shuffle:
            random.shuffle(picks)
        picks = picks if k_per_class is None else picks[:k_per_class]
        print(f"\n=== Class: {cn} (testing {len(picks)} images) ===")
        for i, p in enumerate(picks):
            res = predict_image(p, topk=topk)
            print(f"  [{i+1}] {p.name}")
            for k,(name,score) in enumerate(res,1):
                print(f"     {k:>2d}. {name:>5s}: {score:.3f}")
            print(f"     Pred -> {res[0][0]}  conf={res[0][1]:.3f}")

# ===== examples =====
# 1) ทดสอบจุดเดียว (เลือกภาพที่ 0 ของแต่ละคลาส)
predict_one_sample("Empty", idx=0)
predict_one_sample("WP",    idx=0)
predict_one_sample("BP",    idx=0)
predict_one_sample("BN",    idx=0)
predict_one_sample("BB",    idx=0)

# 2) หรือทดสอบหลายรูปต่อคลาส (สุ่ม 3 รูป)
# predict_many(["Empty", "WP", "BP", "BN", "BB"], k_per_class=3, shuffle=True)




classes: ['BB', 'BK', 'BN', 'BP', 'BQ', 'BR', 'Empty', 'WB', 'WK', 'WN', 'WP', 'WQ', 'WR']

Sample: 2_Move_rotate_student_50.jpg  (class='Empty', idx=0, total=320)
   1.    BP: 0.370
   2.    BR: 0.295
   3.    WN: 0.108
   4.    WR: 0.073
   5. Empty: 0.069
Pred -> BP  conf=0.370

Sample: 040f2bcba5afce3afafdd5bbf36d2ca5_jpg.rf.4b3a8c8430ecaaf5d31ff3b6ff994876_6.jpg  (class='WP', idx=0, total=6690)
   1.    WP: 0.989
   2.    WB: 0.011
   3.    BP: 0.000
   4.    WR: 0.000
   5.    WQ: 0.000
Pred -> WP  conf=0.989

Sample: 03886821377011fec599e8fa12d86e89_jpg.rf.7ec3f29be4f3793b35a2c4a9880d831c_0.jpg  (class='BP', idx=0, total=6948)
   1.    BP: 1.000
   2.    BR: 0.000
   3.    BQ: 0.000
   4.    BN: 0.000
   5.    BB: 0.000
Pred -> BP  conf=1.000

Sample: 03d3ff4582c8125d69c19a72f846bec8_jpg.rf.8cfdbdc73a4c6149758151715b2e8b44_5.jpg  (class='BN', idx=0, total=2014)
   1.    BN: 0.957
   2.    BR: 0.033
   3.    WN: 0.005
   4.    BB: 0.005
   5.    BP: 0.000
Pred -> BN  conf=0.957



[('BB', 0.8839962482452393),
 ('BP', 0.0830816701054573),
 ('BN', 0.0313911959528923),
 ('BQ', 0.0007281983853317797),
 ('BR', 0.0005686705117113888)]