# 03 — Train Per‑Cell CNN (MobileNetV2)

Train a 13‑class classifier on patches at `data/final/train|val/<CLASS>/*.jpg`.

- Works on **Kaggle** (`/kaggle/input` + `/kaggle/working`) and **local repo**.
- Saves model to `models/cell_cnn.h5` (or `/kaggle/working/models/cell_cnn.h5`).


In [None]:
# %%capture
# !pip install --quiet tensorflow==2.* opencv-python albumentations tqdm


In [2]:
print(">>> CELL STARTED")

import os, sys, glob
from pathlib import Path
import tensorflow as tf
from tensorflow.keras.utils import image_dataset_from_directory # type: ignore
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping # type: ignore
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input # type: ignore

ON_KAGGLE = Path('/kaggle').exists()
ROOT = Path('/kaggle/working') if ON_KAGGLE else Path('..')

TRAIN_DIR = ROOT / 'data/final/train'
VAL_DIR   = ROOT / 'data/final/val'
MODEL_DIR = ROOT / 'models'
MODEL_DIR.mkdir(parents=True, exist_ok=True)
MODEL_PATH = MODEL_DIR / 'cell_cnn.h5'

IMG_SIZE = (96, 96)
BATCH = 64
EPOCHS = 10

print('Train dir:', TRAIN_DIR)
print('Val   dir:', VAL_DIR)
print('Model   :', MODEL_PATH)


>>> CELL STARTED
Train dir: ..\data\final\train
Val   dir: ..\data\final\val
Model   : ..\models\cell_cnn.h5


In [3]:
# Build datasets (expects class subfolders).
train_ds = image_dataset_from_directory(
    TRAIN_DIR, labels='inferred', label_mode='int',
    image_size=IMG_SIZE, batch_size=BATCH, shuffle=True)
val_ds = image_dataset_from_directory(
    VAL_DIR, labels='inferred', label_mode='int',
    image_size=IMG_SIZE, batch_size=BATCH, shuffle=False)

class_names = train_ds.class_names
num_classes = len(class_names)
print('Classes:', class_names)

# Prefetch + map preprocess_input
AUTOTUNE = tf.data.AUTOTUNE
def prep(x,y):
    return tf.keras.applications.mobilenet_v2.preprocess_input(tf.cast(x, tf.float32)), y
train_ds = train_ds.map(prep).prefetch(AUTOTUNE)
val_ds   = val_ds.map(prep).prefetch(AUTOTUNE)


Found 13632 files belonging to 13 classes.
Found 2069 files belonging to 13 classes.
Classes: ['BB', 'BK', 'BN', 'BP', 'BQ', 'BR', 'Empty', 'WB', 'WK', 'WN', 'WP', 'WQ', 'WR']


In [4]:
import tensorflow as tf

base = tf.keras.applications.MobileNetV2(
    input_shape=IMG_SIZE + (3,), include_top=False, weights='imagenet'
)
base.trainable = False

inp = tf.keras.layers.Input(IMG_SIZE + (3,))
x = base(inp, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dropout(0.2)(x)
out = tf.keras.layers.Dense(num_classes, activation='softmax')(x)

model = tf.keras.Model(inp, out)
model.compile(optimizer=tf.keras.optimizers.Adam(5e-4),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.summary()


In [5]:
ckpt = ModelCheckpoint(str(MODEL_PATH), monitor='val_accuracy', save_best_only=True, verbose=1)
es   = EarlyStopping(monitor='val_accuracy', patience=3, restore_best_weights=True)
hist = model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS, callbacks=[ckpt, es])
print('Best model saved to:', MODEL_PATH)


Epoch 1/10
[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 139ms/step - accuracy: 0.5911 - loss: 1.4067
Epoch 1: val_accuracy improved from None to 0.94490, saving model to ..\models\cell_cnn.h5




[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 180ms/step - accuracy: 0.7780 - loss: 0.7742 - val_accuracy: 0.9449 - val_loss: 0.2568
Epoch 2/10
[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 202ms/step - accuracy: 0.9342 - loss: 0.2671
Epoch 2: val_accuracy improved from 0.94490 to 0.96085, saving model to ..\models\cell_cnn.h5




[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m51s[0m 236ms/step - accuracy: 0.9442 - loss: 0.2363 - val_accuracy: 0.9609 - val_loss: 0.1787
Epoch 3/10
[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 199ms/step - accuracy: 0.9602 - loss: 0.1733
Epoch 3: val_accuracy improved from 0.96085 to 0.96762, saving model to ..\models\cell_cnn.h5




[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m48s[0m 226ms/step - accuracy: 0.9624 - loss: 0.1643 - val_accuracy: 0.9676 - val_loss: 0.1472
Epoch 4/10
[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 216ms/step - accuracy: 0.9699 - loss: 0.1391
Epoch 4: val_accuracy improved from 0.96762 to 0.97197, saving model to ..\models\cell_cnn.h5




[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m54s[0m 255ms/step - accuracy: 0.9709 - loss: 0.1333 - val_accuracy: 0.9720 - val_loss: 0.1278
Epoch 5/10
[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 218ms/step - accuracy: 0.9735 - loss: 0.1198
Epoch 5: val_accuracy improved from 0.97197 to 0.97390, saving model to ..\models\cell_cnn.h5




[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m54s[0m 252ms/step - accuracy: 0.9743 - loss: 0.1134 - val_accuracy: 0.9739 - val_loss: 0.1157
Epoch 6/10
[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 223ms/step - accuracy: 0.9781 - loss: 0.0981
Epoch 6: val_accuracy improved from 0.97390 to 0.97632, saving model to ..\models\cell_cnn.h5




[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m55s[0m 256ms/step - accuracy: 0.9779 - loss: 0.0966 - val_accuracy: 0.9763 - val_loss: 0.1055
Epoch 7/10
[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 213ms/step - accuracy: 0.9811 - loss: 0.0870
Epoch 7: val_accuracy improved from 0.97632 to 0.97873, saving model to ..\models\cell_cnn.h5




[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m53s[0m 246ms/step - accuracy: 0.9820 - loss: 0.0846 - val_accuracy: 0.9787 - val_loss: 0.0993
Epoch 8/10
[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 217ms/step - accuracy: 0.9812 - loss: 0.0770
Epoch 8: val_accuracy improved from 0.97873 to 0.97922, saving model to ..\models\cell_cnn.h5




[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m53s[0m 249ms/step - accuracy: 0.9823 - loss: 0.0753 - val_accuracy: 0.9792 - val_loss: 0.0929
Epoch 9/10
[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 184ms/step - accuracy: 0.9846 - loss: 0.0676
Epoch 9: val_accuracy did not improve from 0.97922
[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m75s[0m 215ms/step - accuracy: 0.9854 - loss: 0.0662 - val_accuracy: 0.9787 - val_loss: 0.0917
Epoch 10/10
[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 235ms/step - accuracy: 0.9861 - loss: 0.0629
Epoch 10: val_accuracy improved from 0.97922 to 0.98115, saving model to ..\models\cell_cnn.h5




[1m213/213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m58s[0m 270ms/step - accuracy: 0.9861 - loss: 0.0618 - val_accuracy: 0.9812 - val_loss: 0.0861
Best model saved to: ..\models\cell_cnn.h5


### (Optional) Fine‑tune
Unfreeze the base for a few epochs if you want a small boost.


In [None]:
# Optional fine‑tune a few layers
unfreeze = False
if unfreeze:
    base.trainable = True
    model.compile(optimizer=tf.keras.optimizers.Adam(1e-5),
                  loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    hist2 = model.fit(train_ds, validation_data=val_ds, epochs=3, callbacks=[ckpt])


In [6]:
print('✅ Done. Model at:', MODEL_PATH)


✅ Done. Model at: ..\models\cell_cnn.h5


In [7]:
# SAVE CLASS ORDER (robust, handles IMG_SIZE int/tuple)
import json
from pathlib import Path
from tensorflow.keras.utils import image_dataset_from_directory # type: ignore

# --- ensure dirs/vars ---
if 'MODEL_DIR' not in locals():
    MODEL_DIR = Path('models')
MODEL_DIR.mkdir(parents=True, exist_ok=True)
CLASSES_JSON = MODEL_DIR / "classes.json"

# --- normalize image_size to a 2-int tuple ---
def _as_hw_tuple(x):
    # x may be int (96) or tuple like (96,96)
    if isinstance(x, (tuple, list)) and len(x) == 2:
        return (int(x[0]), int(x[1]))
    return (int(x), int(x))

IMG_HW = _as_hw_tuple(IMG_SIZE)

# 1) try to read from train_ds
class_names = None
if 'train_ds' in locals():
    try:
        class_names = list(train_ds.class_names)  # only works before map/prefetch
    except Exception:
        class_names = None

# 2) fallback: build a temporary dataset from folder just to get class_names
if class_names is None:
    if not Path(TRAIN_DIR).exists():
        raise FileNotFoundError(f"TRAIN_DIR not found: {TRAIN_DIR}")
    tmp_ds = image_dataset_from_directory(
        TRAIN_DIR,
        labels='inferred',
        label_mode='int',
        image_size=IMG_HW,      # <— use normalized (h,w)
        batch_size=32,
        shuffle=False
    )
    class_names = list(tmp_ds.class_names)

# 3) save to JSON
CLASSES_JSON.write_text(
    json.dumps(class_names, ensure_ascii=False, indent=2),
    encoding="utf-8"
)
print("✅ Saved class order to:", CLASSES_JSON)
print("   class_names =", class_names)


Found 13632 files belonging to 13 classes.
✅ Saved class order to: ..\models\classes.json
   class_names = ['BB', 'BK', 'BN', 'BP', 'BQ', 'BR', 'Empty', 'WB', 'WK', 'WN', 'WP', 'WQ', 'WR']


In [None]:
# --- Single/Multi-image sanity check (supports Empty, BP/BN/BB/WP/...) ---
from pathlib import Path
import json, random, cv2, numpy as np
import tensorflow as tf
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input

# ===== paths =====
ROOT         = Path("..").resolve()
MODEL_PATH   = ROOT / "models/cell_cnn.h5"
CLASSES_JSON = ROOT / "models/classes.json"

# ===== load model & class order =====
model   = tf.keras.models.load_model(str(MODEL_PATH))
CLASSES = json.loads(CLASSES_JSON.read_text(encoding="utf-8"))
IMG_SIZE = 96  # ต้องตรงกับตอนเทรน/อินเฟอร์

print("classes:", CLASSES)

# ===== util: find candidate images for a given class =====
SEARCH_DIRS = [
    ROOT / "data/final/val",
    ROOT / "data/final/train",
    ROOT / "data/public/cells",
    ROOT / "data/bootstrap/cells",
]
def find_images_for_class(class_name: str):
    exts = ("*.jpg", "*.png", "*.jpeg")
    files = []
    for base in SEARCH_DIRS:
        p = base / class_name
        if p.exists():
            for ext in exts:
                files += sorted(p.glob(ext))
    return files

# ===== core predict =====
def _prep_tensor(bgr, size=IMG_SIZE):
    rgb = cv2.cvtColor(cv2.resize(bgr, (size, size)), cv2.COLOR_BGR2RGB).astype(np.float32)
    x   = preprocess_input(rgb)           # MobileNetV2 preprocess
    return np.expand_dims(x, axis=0)      # (1,H,W,3)

def predict_image(path: Path, topk=5):
    bgr = cv2.imread(str(path), cv2.IMREAD_COLOR)
    assert bgr is not None, f"cannot read: {path}"
    x = _prep_tensor(bgr, IMG_SIZE)
    probs = model.predict(x, verbose=0)[0]     # (C,)
    order = probs.argsort()[::-1]
    topk = min(topk, len(order))
    return [(CLASSES[i], float(probs[i])) for i in order[:topk]]

def predict_one_sample(class_name: str, idx: int | None = None, topk=5):
    cands = find_images_for_class(class_name)
    assert cands, f"ไม่พบรูปของคลาส '{class_name}' ในโฟลเดอร์ที่กำหนด"
    if idx is None:
        idx = 0
    idx = max(0, min(idx, len(cands)-1))
    img_path = cands[idx]
    res = predict_image(img_path, topk=topk)
    print(f"\nSample: {img_path.name}  (class='{class_name}', idx={idx}, total={len(cands)})")
    for k,(name,score) in enumerate(res,1):
        print(f"  {k:>2d}. {name:>5s}: {score:.3f}")
    print(f"Pred -> {res[0][0]}  conf={res[0][1]:.3f}")
    return res

def predict_many(class_names: list[str], k_per_class=3, shuffle=True, topk=5):
    for cn in class_names:
        cands = find_images_for_class(cn)
        if not cands:
            print(f"[skip] no images for '{cn}'"); 
            continue
        picks = cands if k_per_class is None else cands[:]
        if shuffle:
            random.shuffle(picks)
        picks = picks if k_per_class is None else picks[:k_per_class]
        print(f"\n=== Class: {cn} (testing {len(picks)} images) ===")
        for i, p in enumerate(picks):
            res = predict_image(p, topk=topk)
            print(f"  [{i+1}] {p.name}")
            for k,(name,score) in enumerate(res,1):
                print(f"     {k:>2d}. {name:>5s}: {score:.3f}")
            print(f"     Pred -> {res[0][0]}  conf={res[0][1]:.3f}")

# ===== examples =====
# 1) ทดสอบจุดเดียว (เลือกภาพที่ 0 ของแต่ละคลาส)
predict_one_sample("Empty", idx=0)
predict_one_sample("WP",    idx=0)
predict_one_sample("BP",    idx=0)
predict_one_sample("BN",    idx=0)
predict_one_sample("BB",    idx=0)

# 2) หรือทดสอบหลายรูปต่อคลาส (สุ่ม 3 รูป)
# predict_many(["Empty", "WP", "BP", "BN", "BB"], k_per_class=3, shuffle=True)




classes: ['BB', 'BK', 'BN', 'BP', 'BQ', 'BR', 'Empty', 'WB', 'WK', 'WN', 'WP', 'WQ', 'WR']
Using: 040f2bcba5afce3afafdd5bbf36d2ca5_jpg.rf.4b3a8c8430ecaaf5d31ff3b6ff994876_6.jpg

Top-5:
   WP: 0.990
   WB: 0.010
   BP: 0.000
   WQ: 0.000
   WR: 0.000

Pred: WP conf: 0.9897502064704895
