# CNN training notebook (clean split)

Run these in order:
1. **Config + imports**
2. **Dataset helpers**
3. **Train** (preload + timing + parameter print)
4. **Tune** (optional)
5. **Evaluate** (optional)

Notes:
- Dataset scanning + image decoding are intentionally in this notebook (not `model.py`).
- `model.py` is reloaded in the **Config + imports** cell (step 1) so changes to the training loop (like `load_arr`) are picked up without restarting the kernel.

In [None]:
# ==========================================================
# 1) Imports + reload model.py + config
# ==========================================================
from pathlib import Path
import importlib
import cv2
import model as cpp_train

# If model.py changed, ensure we use the latest definitions (e.g., train(..., load_arr=...))
# importlib.reload returns the re-executed module; rebinding keeps cpp_train current.
cpp_train = importlib.reload(cpp_train)

# ---------------------------
# Config
# ---------------------------
# The training cell passes these to cpp_train.train(...) and also exports most of
# them as environment variables before training starts.
DATA_ROOT = "data_2"  # dataset root; one sub-folder per class (see the dataset helpers)
EPOCHS = 14
BATCH_SIZE = 64
VAL_FRAC = 0.2  # fraction of the scanned items held out for validation
LR = 1e-3
SEED = 0  # seeds the train/val shuffle and is forwarded to train(...)

# Optional caps
MAX_IMAGES = 0  # 0 = no cap

# Logging/validation cadence (env-driven inside model.py)
LOG_EVERY_SEC = 0        # 0 disables time-based logging
LOG_EVERY_BATCHES = 1    # print every N batches (1 = every batch)
METRICS_EVERY_BATCHES = 1  # presumably: compute metrics every N batches — confirm in model.py
VAL_EVERY = 1            # validate every N epochs (1 = every epoch)


In [None]:
# ==========================================================
# 2) Dataset helpers (folder-per-class) + image decode
# ==========================================================
def _list_class_dirs(root: Path) -> list[Path]:
    """Return the immediate sub-directories of ``root``, sorted by path."""
    return sorted(p for p in root.iterdir() if p.is_dir())


def infer_class_mapping(root: Path):
    """Derive the class list and the name -> index mapping from the folders under ``root``.

    Returns:
        ``(class_names, class_to_idx)``.

        - If every folder name consists of decimal digits, each folder is mapped to
          its integer value. ``class_names`` then has ``max_index + 1`` entries and
          unused indices hold an empty string.
        - Otherwise the folder names are sorted lexicographically and numbered from 0.

    Raises:
        FileNotFoundError: ``root`` contains no sub-directories.
        ValueError: two numeric folder names denote the same integer
            (for example ``"1"`` and ``"01"``), which would merge two classes.
    """
    names = [d.name for d in _list_class_dirs(root)]
    if not names:
        raise FileNotFoundError(f"no class folders found under: {root}")

    # isdecimal() rather than isdigit(): isdigit() also accepts characters such as
    # superscript digits, which int() rejects with a ValueError.
    if all(n.isdecimal() for n in names):
        class_to_idx = {n: int(n) for n in names}
        if len(set(class_to_idx.values())) != len(names):
            raise ValueError(
                f"numeric class folders under {root} map to duplicate indices: {sorted(names)}"
            )
        class_names = [""] * (max(class_to_idx.values()) + 1)
        for n, idx in class_to_idx.items():
            class_names[idx] = n
        return class_names, class_to_idx

    class_names = sorted(names)
    class_to_idx = {name: i for i, name in enumerate(class_names)}
    return class_names, class_to_idx


def load_dataset_image_label_pairs(root_dir: str | Path, *, max_images: int = 0):
    """Scan a folder-per-class dataset and return ``(items, class_names)``.

    ``items`` is ``[(Path, class_idx), ...]`` covering every file whose suffix is
    one of .png/.jpg/.jpeg/.bmp (case-insensitive). Only paths are collected;
    no image is opened here.

    Args:
        root_dir: Dataset root containing one sub-folder per class.
        max_images: When > 0, stop after this many items. Folders are visited in
            sorted order, so a small cap may only cover the first class folders.

    Raises:
        FileNotFoundError: ``root_dir`` does not exist, has no class folders, or
            contains no matching image files.
    """
    root = Path(root_dir)
    if not root.exists():
        raise FileNotFoundError(f"dataset root not found: {root}")

    class_names, class_to_idx = infer_class_mapping(root)
    exts = {".png", ".jpg", ".jpeg", ".bmp"}
    cap = int(max_images)
    items: list[tuple[Path, int]] = []
    for class_dir in _list_class_dirs(root):
        label = int(class_to_idx[class_dir.name])
        # Path.iterdir() yields entries in arbitrary order. Sorting keeps the item
        # list (and therefore any seeded shuffle/split or cap applied to it)
        # reproducible across runs and machines.
        for p in sorted(class_dir.iterdir()):
            if p.is_file() and p.suffix.lower() in exts:
                items.append((p, label))
                if cap > 0 and len(items) >= cap:
                    return items, class_names

    if not items:
        raise FileNotFoundError(f"no images found under: {root}")
    return items, class_names


def load_image_rgb_u8_hwc_32x32(image_path: str | Path):
    """Read an image with OpenCV and return it as a 32x32 RGB uint8 HWC array.

    Raises:
        FileNotFoundError: OpenCV could not read the file (``cv2.imread``
            returned ``None``).
    """
    bgr = cv2.imread(str(image_path), cv2.IMREAD_COLOR)
    if bgr is None:
        raise FileNotFoundError(f"failed to read image: {image_path}")

    src_h, src_w = bgr.shape[:2]
    # Cubic interpolation when either side has to grow; area averaging otherwise.
    if src_h < 32 or src_w < 32:
        method = cv2.INTER_CUBIC
    else:
        method = cv2.INTER_AREA
    bgr_32 = cv2.resize(bgr, (32, 32), interpolation=method)
    # OpenCV loads channels as BGR; convert to RGB for the caller.
    return cv2.cvtColor(bgr_32, cv2.COLOR_BGR2RGB)


# Model architecture (design + rationale)

**Architecture (forward pass):**
- Input: 32×32 RGB
- 4× Conv(3×3, valid) blocks with BatchNorm + ReLU + Dropout
- MaxPool(2×2) after the 2nd and 4th conv blocks
- Flatten → Dense(512) → Dense(num_classes)

**Why this design fits the assignment:**
- 3×3 valid convolutions keep the model small and fast on 32×32 images.
- BatchNorm stabilizes training with this custom autograd backend.
- Dropout helps regularization and reduces overfitting.
- Two pooling stages reduce spatial size before the fully-connected head.

In [None]:
# ==========================================================
# (Optional) Perf sanity: benchmark conv2d / maxpool2d kernels
# ==========================================================
# This measures the C++ backend performance. If you rebuilt libtensor.dylib
# with -O3, you should see improved timings here.

import time
import importlib

import model as cpp_train
importlib.reload(cpp_train)

from tensor_ctypes import Tensor

def _bench(name, fn, iters=50, warmup=10):
    """Time ``fn`` and print the mean wall-clock milliseconds per call.

    ``fn`` is first called ``warmup`` times without timing, then ``iters`` times
    inside a single ``time.perf_counter()`` window. Nothing is returned; the
    result is printed as one line labelled with ``name``.
    """
    for _ in range(warmup):
        fn()

    started = time.perf_counter()
    for _ in range(iters):
        fn()
    elapsed = time.perf_counter() - started

    print(f"{name:24s}  {elapsed/iters*1e3:8.3f} ms/iter  ({iters} iters)")

# ---------------------------
# 1) Isolated conv2d 3x3 valid (hot path in this project)
# ---------------------------
# A batch of B random 3x32x32 inputs convolved with 32 filters of size 3x3.
# Seeds are fixed, presumably so the inputs are the same on every run.
B = 64
x = Tensor.randn((B, 3, 32, 32), requires_grad=False, seed=0)
w1 = Tensor.randn((32, 3, 3, 3), requires_grad=False, seed=1)
b1 = Tensor.zeros((32,), requires_grad=False)

_bench("conv2d 3x3 valid", lambda: x.conv2d(w1, b1, stride=1, padding=0), iters=30, warmup=5)

# ---------------------------
# 2) MaxPool2d 2x2 stride=2 (hot path)
# ---------------------------
# The pooling input is computed once, outside the timed lambda, so only
# maxpool2d itself is measured.
y = x.conv2d(w1, b1, stride=1, padding=0).relu()  # (B,32,30,30)
_bench("maxpool2d k2 s2", lambda: y.maxpool2d(kernel=2, stride=2), iters=80, warmup=10)

# ---------------------------
# 3) End-to-end CNN forward (no backward)
# ---------------------------
# Try to reuse NUM_CLASSES if earlier cells ran; otherwise fall back to 10.
# (NUM_CLASSES is assigned by the model-stats cell.)
if "NUM_CLASSES" not in globals():
    NUM_CLASSES = 10

# Fresh model instance used only for timing; forward_logits is called with
# training=False.
m_bench = cpp_train.CNN(num_classes=int(NUM_CLASSES))
x2 = Tensor.randn((B, 3, 32, 32), requires_grad=False, seed=123)
_bench("CNN forward", lambda: m_bench.forward_logits(x2, training=False), iters=10, warmup=2)

In [None]:
# ==========================================================
# 3) Model stats: trainable params + MACs/FLOPs per forward
# ==========================================================
from pathlib import Path

# Infer num_classes from folders (does not read images)
# Note: for all-numeric folder names infer_class_mapping sizes class_names as
# (largest index + 1), so NUM_CLASSES can exceed the number of folders.
class_names, _ = infer_class_mapping(Path(DATA_ROOT))
NUM_CLASSES = len(class_names)
print(f"num_classes inferred from {DATA_ROOT!r}: {NUM_CLASSES}")

# Separate model instance used only for the statistics below; the training cell
# builds its own model.
m_stats = cpp_train.CNN(num_classes=NUM_CLASSES)

def _numel(shape):
    """Return the number of elements implied by ``shape``.

    Each dimension is converted with ``int()`` and multiplied together; an
    empty shape yields 1.
    """
    count = 1
    for extent in shape:
        count = count * int(extent)
    return count

# Tally element counts over every tensor the model exposes via named_tensors().
total_params = trainable_params = 0
for name, t in m_stats.named_tensors():
    n = _numel(t.shape())
    total_params = total_params + n
    # NOTE(review): "trainable" is decided by t.has_grad(); confirm in
    # tensor_ctypes that this reports gradient tracking on a freshly built model.
    if t.has_grad():
        trainable_params = trainable_params + n
print(f"Total parameters:     {total_params}")
print(f"Trainable parameters: {trainable_params}")

# ---- MACs / FLOPs (conv + linear only; BN/ReLU/Dropout/pool ignored)
def conv2d_valid_macs(in_h, in_w, in_c, out_c, k=3):
    """Return ``(out_h, out_w, macs)`` for a k x k convolution with no padding, stride 1.

    The output size is ``in - k + 1`` per spatial axis, and the multiply-accumulate
    count is one k*k*in_c dot product per output element per output channel.

    Raises:
        ValueError: the kernel does not fit inside the input.
    """
    out_h, out_w = in_h - k + 1, in_w - k + 1
    if out_w <= 0 or out_h <= 0:
        raise ValueError("invalid conv output shape")
    return out_h, out_w, out_h * out_w * out_c * in_c * k * k

def linear_macs(in_features, out_features):
    """Return the multiply-accumulate count of a dense layer (bias not counted)."""
    macs = in_features * out_features
    return macs

# Walk the layer stack by hand, threading the current feature-map size (h, w)
# and channel count (c) through each step. Each row is (name, out_shape, macs).
# NOTE(review): the layer widths here are hard-coded; keep them in sync with
# cpp_train.CNN in model.py.
rows = []
h, w, c = 32, 32, 3

h, w, mac = conv2d_valid_macs(h, w, c, 32, k=3)  # 32x32 -> 30x30
rows.append(("conv1 3x3", f"{h}x{w}x32", mac))
c = 32
h, w, mac = conv2d_valid_macs(h, w, c, 32, k=3)  # 30x30 -> 28x28
rows.append(("conv2 3x3", f"{h}x{w}x32", mac))

# first 2x2 pool (after conv2): 28x28 -> 14x14
h, w = h // 2, w // 2

h, w, mac = conv2d_valid_macs(h, w, c, 64, k=3)  # 14x14 -> 12x12
rows.append(("conv3 3x3", f"{h}x{w}x64", mac))
c = 64
h, w, mac = conv2d_valid_macs(h, w, c, 64, k=3)  # 12x12 -> 10x10
rows.append(("conv4 3x3", f"{h}x{w}x64", mac))

# second 2x2 pool (after conv4): 10x10 -> 5x5
h, w = h // 2, w // 2
flatten = c * h * w  # 64*5*5=1600
rows.append(("flatten", f"{flatten}", 0))

mac_fc1 = linear_macs(flatten, 512)
rows.append(("fc1", "512", mac_fc1))
mac_fc2 = linear_macs(512, NUM_CLASSES)
rows.append(("fc2", f"{NUM_CLASSES}", mac_fc2))

total_macs = sum(r[2] for r in rows)
total_flops = 2 * total_macs  # common convention: 1 MAC = 2 FLOPs (mul+add)

print("\nPer-layer MACs (batch=1):")
for name, out_shape, macs in rows:
    print(f"- {name:10s}  out={out_shape:10s}  MACs={macs:,}")
print(f"\nTOTAL MACs (batch=1):  {total_macs:,}")
print(f"TOTAL FLOPs (batch=1): {total_flops:,}")

print("\nNotes:")
print("- MACs/FLOPs above count Conv+Linear only.")
print("- BatchNorm/ReLU/Dropout/Pooling also cost ops but are typically reported separately.")


In [None]:
# ==========================================================
# 4) Train: dataset load time + preload time + per-epoch metrics
# ==========================================================
import os
import random
import time

# ---------------------------
# Env vars consumed by model.py
# ---------------------------
# Export the notebook config in one go; values must be strings.
os.environ.update(
    {
        "EPOCHS": str(EPOCHS),
        "BATCH_SIZE": str(BATCH_SIZE),
        "LR": str(LR),
        "SEED": str(SEED),
        "LOG_EVERY_SEC": str(LOG_EVERY_SEC),
        "LOG_EVERY_BATCHES": str(LOG_EVERY_BATCHES),
        "METRICS_EVERY_BATCHES": str(METRICS_EVERY_BATCHES),
        "VAL_EVERY": str(VAL_EVERY),
    }
)

# Output artifacts written after training.
WEIGHTS_OUT = "cnn_weights.pkl"
CLASSES_OUT = "class_names.json"

# ---------------------------
# Dataset scan (paths/labels) time
# ---------------------------
t_scan0 = time.perf_counter()
items, class_names = load_dataset_image_label_pairs(DATA_ROOT, max_images=int(MAX_IMAGES))
t_scan1 = time.perf_counter()
print(f"dataset items: {len(items)}  classes: {len(class_names)}")
print(f"dataset scan time: {t_scan1 - t_scan0:.3f}s")

# ---------------------------
# Split train/val
# ---------------------------
# Shuffle the item indices with a seeded RNG; the first val_n shuffled indices
# become the validation set and the remainder the training set.
rng = random.Random(int(SEED))
idxs = list(range(len(items)))
rng.shuffle(idxs)
val_n = int(round(len(idxs) * float(VAL_FRAC)))
if float(VAL_FRAC) > 0.0 and len(idxs) >= 2:
    # Keep at least one item on each side of the split.
    val_n = min(max(1, val_n), len(idxs) - 1)
val_set = set(idxs[:val_n])
val_items = [items[i] for i in idxs[:val_n]]
train_items = [items[i] for i in idxs[val_n:]]
print(f"split: train={len(train_items)} val={len(val_items)}")

# ---------------------------
# Preload images + time it
# ---------------------------
# Decode every image once up front, keyed by its path string.
t0 = time.perf_counter()
preloaded = {str(path): load_image_rgb_u8_hwc_32x32(path) for path, _label in items}
t1 = time.perf_counter()
dt = t1 - t0
rate = (len(items) / dt) if dt > 0 else 0.0
print(f"preloaded {len(items)} images in {dt:.3f}s  ({rate:.1f} images/sec)")

def load_arr(path):
    """Return the already-decoded image for ``path`` from the ``preloaded`` dict.

    The lookup key is ``str(path)``; a path that was not preloaded raises
    ``KeyError``.
    """
    cache_key = str(path)
    return preloaded[cache_key]

# ---------------------------
# Train (collect per-epoch history for plots)
# ---------------------------
# Images come from the in-memory cache through load_arr.
m = cpp_train.CNN(num_classes=len(class_names))
history = cpp_train.train(
    m,
    train_items,
    val_items,
    batch_size=int(BATCH_SIZE),
    epochs=int(EPOCHS),
    lr=float(LR),
    seed=int(SEED),
    load_arr=load_arr,
    return_history=True,
)

# ---------------------------
# Save artifacts
# ---------------------------
cpp_train.save_weights(m, WEIGHTS_OUT)
cpp_train.save_class_names(class_names, CLASSES_OUT)
print("saved weights:", WEIGHTS_OUT)
print("saved class names:", CLASSES_OUT)

# Quick summary: report the epoch with the highest validation accuracy,
# skipping epochs that recorded no val_acc.
if history:
    best = max(
        filter(lambda rec: rec.get("val_acc") is not None, history),
        key=lambda rec: rec["val_acc"],
        default=None,
    )
    if best is not None:
        print(f"best val_acc={best['val_acc']:.4f} at epoch {best['epoch']}")


In [None]:
# ==========================================================
# 5) Plots: loss/accuracy vs epoch (train + val)
# ==========================================================
import math

try:
    import matplotlib.pyplot as plt
except Exception as e:
    raise ImportError("matplotlib is required for plots. Install it (e.g. pip install matplotlib) and re-run.") from e

# `history` is created by the training cell. Looking it up through globals()
# gives the intended hint on a fresh kernel instead of a bare NameError.
if not globals().get("history"):
    raise RuntimeError("No training history found. Run the training cell above first.")

# One entry per epoch; validation metrics may be missing (None) for epochs
# without a validation pass, so those become NaN and leave gaps in the plot.
epochs = [h["epoch"] for h in history]
train_loss = [h["train_loss"] for h in history]
train_acc = [h["train_acc"] for h in history]
val_loss = [h["val_loss"] if h.get("val_loss") is not None else math.nan for h in history]
val_acc = [h["val_acc"] if h.get("val_acc") is not None else math.nan for h in history]
# Treat a missing or None "seconds" the same way so the timing table below can
# always format a float.
secs = [h["seconds"] if h.get("seconds") is not None else math.nan for h in history]
throughput = [(h.get("train_images", 0) / h["seconds"]) if h.get("seconds") else math.nan for h in history]

fig, ax = plt.subplots(1, 2, figsize=(12, 4))
ax[0].plot(epochs, train_loss, label="train_loss")
ax[0].plot(epochs, val_loss, label="val_loss")
ax[0].set_xlabel("epoch")
ax[0].set_ylabel("loss")
ax[0].set_title("Loss")
ax[0].grid(True, alpha=0.3)
ax[0].legend()

ax[1].plot(epochs, train_acc, label="train_acc")
ax[1].plot(epochs, val_acc, label="val_acc")
ax[1].set_xlabel("epoch")
ax[1].set_ylabel("accuracy")
ax[1].set_title("Accuracy")
ax[1].grid(True, alpha=0.3)
ax[1].legend()

plt.tight_layout()
plt.show()

print("Per-epoch timing / throughput:")
for e, s, th in zip(epochs, secs, throughput):
    print(f"epoch {e:3d}  seconds={s:8.3f}  train_img/s={th:8.1f}")


In [None]:
# Tuning cycle: load saved weights + tune on a small subset with augmentation
import os
import random
import time

# Tuning hyper-parameters (learning rate is half of the training LR).
TUNE_EPOCHS = 5
TUNE_BATCH_SIZE = int(BATCH_SIZE)
TUNE_LR = float(LR) * 0.5
TUNE_VAL_FRAC = 0.2
TUNE_IMAGES = 10000

# Re-scan the dataset (honouring MAX_IMAGES) as the pool to sample from.
items, class_names = load_dataset_image_label_pairs(DATA_ROOT, max_images=int(MAX_IMAGES))
print(f"tuning source size: {len(items)} (MAX_IMAGES={MAX_IMAGES})")

# Shuffle with a fixed seed and keep at most TUNE_IMAGES items
# (TUNE_IMAGES <= 0 keeps everything).
rng = random.Random(123)
idxs = list(range(len(items)))
rng.shuffle(idxs)
if int(TUNE_IMAGES) > 0:
    idxs = idxs[: int(TUNE_IMAGES)]
subset = [items[i] for i in idxs]
print(f"tuning subset size: {len(subset)} (TUNE_IMAGES={TUNE_IMAGES})")

# Decode only the sampled images, timing the pass.
t0 = time.perf_counter()
preloaded = {str(path): load_image_rgb_u8_hwc_32x32(path) for path, _label in subset}
t1 = time.perf_counter()
print(f"preloaded tuning subset in {t1 - t0:.3f}s")

def load_arr(path):
    """Return the already-decoded image for ``path`` from the ``preloaded`` dict.

    The lookup key is ``str(path)``; a path that was not preloaded raises
    ``KeyError``.
    """
    cache_key = str(path)
    return preloaded[cache_key]

# Recreate model + load saved weights
m = cpp_train.CNN(num_classes=len(class_names))
WEIGHTS_PATH = "cnn_weights.pkl"
if not os.path.exists(WEIGHTS_PATH):
    raise FileNotFoundError(f"weights not found: {WEIGHTS_PATH}. Run the training cell first.")
# NOTE(review): the .pkl extension suggests a pickle-based format; only load
# weight files produced by this notebook.
cpp_train.load_weights(m, WEIGHTS_PATH)
print("loaded weights:", WEIGHTS_PATH)

# Split subset into tune-train / tune-val
# Same rule as the training cell's split: no validation set when the fraction is
# <= 0 or there are fewer than two items; otherwise keep at least one item on
# each side.
# NOTE(review): `subset` is sampled from the whole dataset with its own seed, so
# tune_val can contain images the loaded weights were already trained on.
if len(subset) < 2 or float(TUNE_VAL_FRAC) <= 0.0:
    val_n = 0
else:
    val_n = int(round(len(subset) * float(TUNE_VAL_FRAC)))
    val_n = min(max(1, val_n), len(subset) - 1)
tune_val = subset[:val_n]
tune_train = subset[val_n:]
print(f"tune split: train={len(tune_train)} val={len(tune_val)}")

# Make it chatty during tuning
os.environ["LOG_EVERY_SEC"] = "0"
os.environ["LOG_EVERY_BATCHES"] = "1"
os.environ["METRICS_EVERY_BATCHES"] = "1"
os.environ["VAL_EVERY"] = "1"

# Enable augmentation during tuning (read by model.py)
# NOTE(review): these variables stay set for the rest of the kernel session, so
# re-running the training cell afterwards would also see AUGMENT=1.
os.environ["AUGMENT"] = "1"
os.environ["AUG_FLIP_PROB"] = "0.5"
os.environ["AUG_TRANSLATE"] = "0.1"
os.environ["AUG_ROTATE_DEG"] = "5"

cpp_train.train(
    m,
    tune_train,
    tune_val,
    batch_size=int(TUNE_BATCH_SIZE),
    epochs=int(TUNE_EPOCHS),
    lr=float(TUNE_LR),
    seed=123,
    load_arr=load_arr,
)

# Save tuned weights separately
TUNED_OUT = "cnn_weights_tuned_1.pkl"
cpp_train.save_weights(m, TUNED_OUT)
print("saved tuned weights:", TUNED_OUT)


In [None]:
# Evaluate tuned weights on test images
import os
import time

# Evaluation inputs: a folder-per-class test set and the weights written by the
# tuning cell.
TEST_ROOT = "testing"
TUNED_WEIGHTS = "cnn_weights_tuned_1.pkl"
BATCH_SIZE_EVAL = int(BATCH_SIZE)

# Fail early, before scanning or decoding any images.
if not os.path.exists(TUNED_WEIGHTS):
    raise FileNotFoundError(f"tuned weights not found: {TUNED_WEIGHTS}")

# max_images=0 means no cap: every test image is evaluated.
test_items, test_class_names = load_dataset_image_label_pairs(TEST_ROOT, max_images=0)
print(f"test dataset: n={len(test_items)} classes={len(test_class_names)}")

# Preload test set (timed)
# This rebinds the notebook-global `preloaded`, replacing any cache left by the
# training or tuning cells.
t0 = time.perf_counter()
preloaded = {str(p): load_image_rgb_u8_hwc_32x32(p) for p, _ in test_items}
t1 = time.perf_counter()
print(f"preloaded test set in {t1 - t0:.3f}s")

def load_arr(path):
    """Return the already-decoded image for ``path`` from the ``preloaded`` dict.

    The lookup key is ``str(path)``; a path that was not preloaded raises
    ``KeyError``.
    """
    cache_key = str(path)
    return preloaded[cache_key]

# Class indices are derived from folder names, so the test folders must yield the
# same class list the weights were trained with. If an earlier cell left
# `class_names` in the notebook namespace, warn (without aborting) on a mismatch;
# otherwise the reported accuracy could be computed against shifted labels.
if "class_names" in globals() and list(class_names) != list(test_class_names):
    print(
        "WARNING: class folders under the test root differ from the training class list "
        f"(test={len(test_class_names)}, train={len(class_names)}); "
        "label indices may not line up with the loaded weights."
    )

# Build a fresh model sized for the test class list and load the tuned weights.
m_test = cpp_train.CNN(num_classes=len(test_class_names))
cpp_train.load_weights(m_test, TUNED_WEIGHTS)
print("loaded tuned weights:", TUNED_WEIGHTS)

# NOTE(review): _eval_epoch is a private helper of model.py; it is used here as
# returning a (loss, accuracy) pair.
test_loss, test_acc = cpp_train._eval_epoch(
    m_test, test_items, batch_size=max(1, int(BATCH_SIZE_EVAL)), load_arr=load_arr
)

print(f"TEST  loss={test_loss:.6f}  acc={test_acc:.4f}")
