# 03 — Train Per‑Cell CNN (MobileNetV2)

Train a 13‑class classifier on patches at `data/final/train|val/<CLASS>/*.jpg`.

- Works on **Kaggle** (`/kaggle/input` + `/kaggle/working`) and **local repo**.
- Saves model to `models/cell_cnn.h5` (or `/kaggle/working/models/cell_cnn.h5`).


In [1]:
# %%capture
# !pip install --quiet tensorflow==2.* opencv-python albumentations tqdm


In [2]:
print(">>> CELL STARTED")

import os, sys, glob
from pathlib import Path
import tensorflow as tf
from tensorflow.keras.utils import image_dataset_from_directory # type: ignore
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping # type: ignore
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input # type: ignore

ON_KAGGLE = Path('/kaggle').exists()
ROOT = Path('/kaggle/working') if ON_KAGGLE else Path('..')

# NEW: Use balanced dataset instead of old imbalanced dataset
BALANCED_DATA = ROOT / 'data/balanced/cells'
OLD_TRAIN_DIR = ROOT / 'data/final/train'
OLD_VAL_DIR   = ROOT / 'data/final/val'

# Check which dataset to use
if BALANCED_DATA.exists():
    print("✅ Using NEW BALANCED dataset (recommended!)")
    DATA_DIR = BALANCED_DATA
    USE_BALANCED = True
else:
    print("⚠️ Using OLD dataset (imbalanced - Empty cells underrepresented)")
    DATA_DIR = OLD_TRAIN_DIR
    USE_BALANCED = False

MODEL_DIR = ROOT / 'models'
MODEL_DIR.mkdir(parents=True, exist_ok=True)
MODEL_PATH = MODEL_DIR / 'cell_cnn.h5'

IMG_SIZE = (96, 96)
BATCH = 64
EPOCHS = 10
VAL_SPLIT = 0.2  # Use 20% for validation when using balanced dataset

print('Data dir:', DATA_DIR)
print('Model   :', MODEL_PATH)
print('Using balanced dataset:', USE_BALANCED)

>>> CELL STARTED
✅ Using NEW BALANCED dataset (recommended!)
Data dir: ..\data\balanced\cells
Model   : ..\models\cell_cnn.h5
Using balanced dataset: True


In [3]:
# Build datasets - handles both balanced (single dir) and old (train/val split)
if USE_BALANCED:
    # Balanced dataset: use validation_split
    print(f"Loading balanced dataset with {VAL_SPLIT*100:.0f}% validation split...")
    
    train_ds = image_dataset_from_directory(
        DATA_DIR, 
        labels='inferred', 
        label_mode='int',
        validation_split=VAL_SPLIT,
        subset='training',
        seed=123,
        image_size=IMG_SIZE, 
        batch_size=BATCH, 
        shuffle=True
    )
    
    val_ds = image_dataset_from_directory(
        DATA_DIR, 
        labels='inferred', 
        label_mode='int',
        validation_split=VAL_SPLIT,
        subset='validation',
        seed=123,
        image_size=IMG_SIZE, 
        batch_size=BATCH, 
        shuffle=False
    )
else:
    # Old dataset: separate train/val directories
    print("Loading from separate train/val directories...")
    
    train_ds = image_dataset_from_directory(
        OLD_TRAIN_DIR, 
        labels='inferred', 
        label_mode='int',
        image_size=IMG_SIZE, 
        batch_size=BATCH, 
        shuffle=True
    )
    
    val_ds = image_dataset_from_directory(
        OLD_VAL_DIR, 
        labels='inferred', 
        label_mode='int',
        image_size=IMG_SIZE, 
        batch_size=BATCH, 
        shuffle=False
    )

class_names = train_ds.class_names
num_classes = len(class_names)
print('Classes:', class_names)
print(f'Total classes: {num_classes}')

# Prefetch + map preprocess_input
AUTOTUNE = tf.data.AUTOTUNE
def prep(x,y):
    return tf.keras.applications.mobilenet_v2.preprocess_input(tf.cast(x, tf.float32)), y
train_ds = train_ds.map(prep).prefetch(AUTOTUNE)
val_ds   = val_ds.map(prep).prefetch(AUTOTUNE)

print("✅ Dataset loaded and preprocessed!")

Loading balanced dataset with 20% validation split...
Found 16196 files belonging to 13 classes.
Using 12957 files for training.
Found 16196 files belonging to 13 classes.
Using 3239 files for validation.
Classes: ['BB', 'BK', 'BN', 'BP', 'BQ', 'BR', 'Empty', 'WB', 'WK', 'WN', 'WP', 'WQ', 'WR']
Total classes: 13
✅ Dataset loaded and preprocessed!


In [5]:
import tensorflow as tf

base = tf.keras.applications.MobileNetV2(
    input_shape=IMG_SIZE + (3,), include_top=False, weights='imagenet'
)
base.trainable = False

inp = tf.keras.layers.Input(IMG_SIZE + (3,))
x = base(inp, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dropout(0.2)(x)
out = tf.keras.layers.Dense(num_classes, activation='softmax')(x)

model = tf.keras.Model(inp, out)
model.compile(optimizer=tf.keras.optimizers.Adam(5e-4),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.summary()


In [6]:
# Compute class weights to fix Empty cell bias
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

print("Computing class weights to balance training...")

# Count samples per class
class_counts = {}
for i, class_name in enumerate(class_names):
    class_path = DATA_DIR / class_name
    if class_path.exists():
        count = len(list(class_path.glob("*.jpg")))
        class_counts[i] = count
        print(f"  {class_name:>5}: {count:4d} images")

# Compute balanced weights
class_indices = list(class_counts.keys())
class_samples = list(class_counts.values())

# Create array of class labels proportional to their frequency
y_train = np.repeat(class_indices, class_samples)

# Compute weights
weights = compute_class_weight('balanced', classes=np.array(class_indices), y=y_train)
class_weight = {i: w for i, w in zip(class_indices, weights)}

print("\nComputed class weights:")
for i, class_name in enumerate(class_names):
    print(f"  {class_name:>5}: {class_weight[i]:.3f}")

# Highlight Empty weight
empty_idx = class_names.index('Empty')
print(f"\n✅ Empty class weight: {class_weight[empty_idx]:.3f}")
print("   (Higher weight = model penalized more for Empty errors)")

# Train with class weights
print("\n" + "="*60)
print("Starting training with class weights...")
print("="*60)

ckpt = ModelCheckpoint(str(MODEL_PATH), monitor='val_accuracy', save_best_only=True, verbose=1)
es   = EarlyStopping(monitor='val_accuracy', patience=3, restore_best_weights=True)

hist = model.fit(
    train_ds, 
    validation_data=val_ds, 
    epochs=EPOCHS, 
    callbacks=[ckpt, es],
    class_weight=class_weight  # ← KEY FIX: Use class weights!
)

print('Best model saved to:', MODEL_PATH)
print("\n✅ Training complete with class weight balancing!")

Computing class weights to balance training...
     BB:  684 images
     BK:  705 images
     BN:  973 images
     BP: 3361 images
     BQ:  430 images
     BR: 1007 images
  Empty: 1784 images
     WB:  860 images
     WK:  725 images
     WN:  944 images
     WP: 3236 images
     WQ:  552 images
     WR:  935 images

Computed class weights:
     BB: 1.821
     BK: 1.767
     BN: 1.280
     BP: 0.371
     BQ: 2.897
     BR: 1.237
  Empty: 0.698
     WB: 1.449
     WK: 1.718
     WN: 1.320
     WP: 0.385
     WQ: 2.257
     WR: 1.332

✅ Empty class weight: 0.698
   (Higher weight = model penalized more for Empty errors)

Starting training with class weights...
Epoch 1/10
[1m203/203[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 132ms/step - accuracy: 0.5606 - loss: 1.5254
Epoch 1: val_accuracy improved from None to 0.93856, saving model to ..\models\cell_cnn.h5




[1m203/203[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m40s[0m 184ms/step - accuracy: 0.7536 - loss: 0.9174 - val_accuracy: 0.9386 - val_loss: 0.3182
Epoch 2/10
[1m203/203[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 155ms/step - accuracy: 0.9286 - loss: 0.3703
Epoch 2: val_accuracy improved from 0.93856 to 0.95709, saving model to ..\models\cell_cnn.h5




[1m203/203[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m40s[0m 194ms/step - accuracy: 0.9333 - loss: 0.3330 - val_accuracy: 0.9571 - val_loss: 0.1873
Epoch 3/10
[1m203/203[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 150ms/step - accuracy: 0.9461 - loss: 0.2530
Epoch 3: val_accuracy improved from 0.95709 to 0.95801, saving model to ..\models\cell_cnn.h5




[1m203/203[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m38s[0m 188ms/step - accuracy: 0.9491 - loss: 0.2397 - val_accuracy: 0.9580 - val_loss: 0.1731
Epoch 4/10
[1m203/203[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 149ms/step - accuracy: 0.9586 - loss: 0.2040
Epoch 4: val_accuracy improved from 0.95801 to 0.96511, saving model to ..\models\cell_cnn.h5




[1m203/203[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m39s[0m 189ms/step - accuracy: 0.9580 - loss: 0.1980 - val_accuracy: 0.9651 - val_loss: 0.1444
Epoch 5/10
[1m203/203[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 150ms/step - accuracy: 0.9632 - loss: 0.1680
Epoch 5: val_accuracy improved from 0.96511 to 0.96882, saving model to ..\models\cell_cnn.h5




[1m203/203[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m38s[0m 188ms/step - accuracy: 0.9633 - loss: 0.1693 - val_accuracy: 0.9688 - val_loss: 0.1189
Epoch 6/10
[1m203/203[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 156ms/step - accuracy: 0.9697 - loss: 0.1439
Epoch 6: val_accuracy improved from 0.96882 to 0.97098, saving model to ..\models\cell_cnn.h5




[1m203/203[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m40s[0m 198ms/step - accuracy: 0.9675 - loss: 0.1418 - val_accuracy: 0.9710 - val_loss: 0.1062
Epoch 7/10
[1m203/203[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 160ms/step - accuracy: 0.9704 - loss: 0.1316
Epoch 7: val_accuracy improved from 0.97098 to 0.97499, saving model to ..\models\cell_cnn.h5




[1m203/203[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 201ms/step - accuracy: 0.9719 - loss: 0.1282 - val_accuracy: 0.9750 - val_loss: 0.0871
Epoch 8/10
[1m203/203[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 159ms/step - accuracy: 0.9751 - loss: 0.1127
Epoch 8: val_accuracy did not improve from 0.97499
[1m203/203[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m42s[0m 207ms/step - accuracy: 0.9749 - loss: 0.1138 - val_accuracy: 0.9734 - val_loss: 0.0985
Epoch 9/10
[1m203/203[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 158ms/step - accuracy: 0.9757 - loss: 0.1054
Epoch 9: val_accuracy improved from 0.97499 to 0.97808, saving model to ..\models\cell_cnn.h5




[1m203/203[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 199ms/step - accuracy: 0.9757 - loss: 0.1043 - val_accuracy: 0.9781 - val_loss: 0.0799
Epoch 10/10
[1m203/203[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 147ms/step - accuracy: 0.9780 - loss: 0.0955
Epoch 10: val_accuracy improved from 0.97808 to 0.98178, saving model to ..\models\cell_cnn.h5




[1m203/203[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m37s[0m 182ms/step - accuracy: 0.9779 - loss: 0.0957 - val_accuracy: 0.9818 - val_loss: 0.0641
Best model saved to: ..\models\cell_cnn.h5

✅ Training complete with class weight balancing!


### (Optional) Fine‑tune
Unfreeze the base for a few epochs if you want a small boost.


In [None]:
# Optional fine‑tune a few layers
unfreeze = False
if unfreeze:
    base.trainable = True
    model.compile(optimizer=tf.keras.optimizers.Adam(1e-5),
                  loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    hist2 = model.fit(train_ds, validation_data=val_ds, epochs=3, callbacks=[ckpt])


In [7]:
print('✅ Done. Model at:', MODEL_PATH)


✅ Done. Model at: ..\models\cell_cnn.h5


In [8]:
# SAVE CLASS ORDER (robust, handles IMG_SIZE int/tuple)
import json
from pathlib import Path
from tensorflow.keras.utils import image_dataset_from_directory # type: ignore

# --- ensure dirs/vars ---
if 'MODEL_DIR' not in locals():
    MODEL_DIR = Path('models')
MODEL_DIR.mkdir(parents=True, exist_ok=True)
CLASSES_JSON = MODEL_DIR / "classes.json"

# --- normalize image_size to a 2-int tuple ---
def _as_hw_tuple(x):
    # x may be int (96) or tuple like (96,96)
    if isinstance(x, (tuple, list)) and len(x) == 2:
        return (int(x[0]), int(x[1]))
    return (int(x), int(x))

IMG_HW = _as_hw_tuple(IMG_SIZE)

# 1) try to read from train_ds
class_names = None
if 'train_ds' in locals():
    try:
        class_names = list(train_ds.class_names)  # only works before map/prefetch
    except Exception:
        class_names = None

# 2) fallback: build a temporary dataset from folder just to get class_names
if class_names is None:
    # Use DATA_DIR (works with both balanced and old dataset)
    if not Path(DATA_DIR).exists():
        raise FileNotFoundError(f"DATA_DIR not found: {DATA_DIR}")
    tmp_ds = image_dataset_from_directory(
        DATA_DIR,
        labels='inferred',
        label_mode='int',
        image_size=IMG_HW,      # <— use normalized (h,w)
        batch_size=32,
        shuffle=False
    )
    class_names = list(tmp_ds.class_names)

# 3) save to JSON
CLASSES_JSON.write_text(
    json.dumps(class_names, ensure_ascii=False, indent=2),
    encoding="utf-8"
)
print("✅ Saved class order to:", CLASSES_JSON)
print("   class_names =", class_names)

Found 16196 files belonging to 13 classes.
✅ Saved class order to: ..\models\classes.json
   class_names = ['BB', 'BK', 'BN', 'BP', 'BQ', 'BR', 'Empty', 'WB', 'WK', 'WN', 'WP', 'WQ', 'WR']


In [9]:
# --- Single/Multi-image sanity check (supports Empty, BP/BN/BB/WP/...) ---
from pathlib import Path
import json, random, cv2, numpy as np
import tensorflow as tf
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input

# ===== paths =====
ROOT         = Path("..").resolve()
MODEL_PATH   = ROOT / "models/cell_cnn.h5"
CLASSES_JSON = ROOT / "models/classes.json"

# ===== load model & class order =====
model   = tf.keras.models.load_model(str(MODEL_PATH))
CLASSES = json.loads(CLASSES_JSON.read_text(encoding="utf-8"))
IMG_SIZE = 96  # must match training size

print("classes:", CLASSES)

# ===== util: find candidate images for a given class =====
# Search in BALANCED dataset FIRST, then fall back to old dataset
SEARCH_DIRS = [
    ROOT / "data/balanced/cells",      # NEW balanced dataset (priority)
    ROOT / "data/final/val",           # Old validation set
    ROOT / "data/final/train",         # Old training set
    ROOT / "data/public/cells",
    ROOT / "data/bootstrap/cells",
]

def find_images_for_class(class_name: str):
    exts = ("*.jpg", "*.png", "*.jpeg")
    files = []
    for base in SEARCH_DIRS:
        p = base / class_name
        if p.exists():
            for ext in exts:
                files += sorted(p.glob(ext))
    return files

# ===== core predict =====
def _prep_tensor(bgr, size=IMG_SIZE):
    rgb = cv2.cvtColor(cv2.resize(bgr, (size, size)), cv2.COLOR_BGR2RGB).astype(np.float32)
    x   = preprocess_input(rgb)           # MobileNetV2 preprocess
    return np.expand_dims(x, axis=0)      # (1,H,W,3)

def predict_image(path: Path, topk=5):
    bgr = cv2.imread(str(path), cv2.IMREAD_COLOR)
    assert bgr is not None, f"cannot read: {path}"
    x = _prep_tensor(bgr, IMG_SIZE)
    probs = model.predict(x, verbose=0)[0]     # (C,)
    order = probs.argsort()[::-1]
    topk = min(topk, len(order))
    return [(CLASSES[i], float(probs[i])) for i in order[:topk]]

def predict_one_sample(class_name: str, idx: int | None = None, topk=5):
    cands = find_images_for_class(class_name)
    assert cands, f"No images found for class '{class_name}' in search directories"
    if idx is None:
        idx = 0
    idx = max(0, min(idx, len(cands)-1))
    img_path = cands[idx]
    res = predict_image(img_path, topk=topk)
    print(f"\nSample: {img_path.name}  (class='{class_name}', idx={idx}, total={len(cands)})")
    for k,(name,score) in enumerate(res,1):
        print(f"  {k:>2d}. {name:>5s}: {score:.3f}")
    
    # Check if prediction is correct
    correct = "✅ CORRECT" if res[0][0] == class_name else "❌ WRONG"
    print(f"Pred -> {res[0][0]}  conf={res[0][1]:.3f}  {correct}")
    return res

def predict_many(class_names: list[str], k_per_class=3, shuffle=True, topk=5):
    total_tested = 0
    total_correct = 0
    
    for cn in class_names:
        cands = find_images_for_class(cn)
        if not cands:
            print(f"[skip] no images for '{cn}'"); 
            continue
        picks = cands if k_per_class is None else cands[:]
        if shuffle:
            random.shuffle(picks)
        picks = picks if k_per_class is None else picks[:k_per_class]
        print(f"\n=== Class: {cn} (testing {len(picks)} images) ===")
        
        class_correct = 0
        for i, p in enumerate(picks):
            res = predict_image(p, topk=topk)
            is_correct = res[0][0] == cn
            if is_correct:
                class_correct += 1
                total_correct += 1
            total_tested += 1
            
            correct_mark = "✅" if is_correct else "❌"
            print(f"  [{i+1}] {p.name} {correct_mark}")
            for k,(name,score) in enumerate(res,1):
                print(f"     {k:>2d}. {name:>5s}: {score:.3f}")
            print(f"     Pred -> {res[0][0]}  conf={res[0][1]:.3f}")
        
        print(f"  Class accuracy: {class_correct}/{len(picks)} = {class_correct/len(picks)*100:.1f}%")
    
    if total_tested > 0:
        print(f"\n{'='*60}")
        print(f"OVERALL ACCURACY: {total_correct}/{total_tested} = {total_correct/total_tested*100:.1f}%")
        print(f"{'='*60}")

# ===== examples =====
print("\n" + "="*60)
print("TESTING MODEL ON SAMPLE IMAGES")
print("="*60)

# 1) Test individual samples
predict_one_sample("Empty", idx=0)
predict_one_sample("WP",    idx=0)
predict_one_sample("BP",    idx=0)
predict_one_sample("BN",    idx=0)
predict_one_sample("BB",    idx=0)

# 2) Comprehensive test on Empty class (the problematic one)
print("\n" + "="*60)
print("COMPREHENSIVE TEST: Empty class (3 random samples)")
print("="*60)
predict_many(["Empty"], k_per_class=3, shuffle=True)

# 3) Optional: test all classes
# predict_many(CLASSES, k_per_class=2, shuffle=True)



classes: ['BB', 'BK', 'BN', 'BP', 'BQ', 'BR', 'Empty', 'WB', 'WK', 'WN', 'WP', 'WQ', 'WR']

TESTING MODEL ON SAMPLE IMAGES

Sample: 2_Move_rotate_student_20.jpg  (class='Empty', idx=0, total=3773)
   1. Empty: 0.740
   2.    BP: 0.110
   3.    WP: 0.108
   4.    BN: 0.015
   5.    WB: 0.011
Pred -> Empty  conf=0.740  ✅ CORRECT

Sample: 0301b7f9ed4d5ba503fda79fc4370c29_jpg.rf.56da1174519560712119d3fc195068cb_0.jpg  (class='WP', idx=0, total=9708)
   1.    WP: 0.987
   2.    WN: 0.009
   3.    WB: 0.003
   4.    BP: 0.001
   5.    WR: 0.001
Pred -> WP  conf=0.987  ✅ CORRECT

Sample: 03886821377011fec599e8fa12d86e89_jpg.rf.44fb00bcea92435e28c1ea1a89595b32_0.jpg  (class='BP', idx=0, total=10083)
   1.    BP: 0.997
   2.    BN: 0.002
   3.    BR: 0.001
   4.    BB: 0.000
   5.    WN: 0.000
Pred -> BP  conf=0.997  ✅ CORRECT

Sample: 03d3ff4582c8125d69c19a72f846bec8_jpg.rf.0abd6e8d01091ac5396f7a9cf390bdc9_4.jpg  (class='BN', idx=0, total=2919)
   1.    BN: 0.946
   2.    WN: 0.028
   3.    BB