# TensorFlow CNN Lab: Fashion-MNIST

### Why this notebook
- Provide a guided run-through of the TensorFlow CNN implementation in `../src`.
- Capture the minimal steps needed to train, evaluate, and deploy the model.
- Complement the PyTorch notebook for framework comparison.

### Learning objectives
- Train the CNN using Keras utilities and inspect accuracy/loss metrics.
- Load saved checkpoints and perform predictions on sample images.
- Understand where to customise architecture, optimisers, or augmentations.
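The notebook does not reproduce the architecture defined in `../src`. As a reference point for the customisation objective, the sketch below builds a small Keras CNN for 28x28 grayscale inputs and marks where the layers and the optimiser would typically change. `build_cnn`, `sketch_model`, the layer sizes, and the learning rate are all hypothetical illustrations, not the model that `train.py` builds.

```python
import tensorflow as tf


def build_cnn(num_classes: int = 10, learning_rate: float = 1e-3) -> tf.keras.Model:
    """Hypothetical small CNN for 28x28x1 images (not the model in ../src)."""
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(28, 28, 1)),
        # Scale uint8 pixel values from [0, 255] to [0, 1] inside the model.
        tf.keras.layers.Rescaling(1.0 / 255),
        # Architecture customisation point: number and width of conv blocks.
        tf.keras.layers.Conv2D(32, 3, activation="relu"),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Conv2D(64, 3, activation="relu"),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dropout(0.3),
        tf.keras.layers.Dense(num_classes),  # raw logits, one per class
    ])
    # Optimiser customisation point: swap Adam for SGD, RMSprop, etc.
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )
    return model


sketch_model = build_cnn()
sketch_model.summary()
```

Because the final layer emits logits, the loss is configured with `from_logits=True`; keeping softmax out of the model is a common Keras choice for numerical stability.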

### Prerequisites
- TensorFlow 2.x installed (GPU support optional).
- Familiarity with the PyTorch CNN notebook helps with comparisons.
- Optional: Matplotlib for additional visualisations.
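A quick way to confirm the prerequisites before training is to print the TensorFlow version and the GPUs it can see. An empty GPU list simply means training will run on the CPU; on Apple silicon, a GPU appears here only when the `tensorflow-metal` plugin is installed.

```python
import tensorflow as tf

# This notebook assumes TensorFlow 2.x.
print("TensorFlow", tf.__version__)

# Physical GPUs visible to TensorFlow; an empty list means CPU-only training.
gpus = tf.config.list_physical_devices("GPU")
print("GPUs:", gpus if gpus else "none (CPU only)")
```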

### Notebook workflow
1. Import config and helper functions from the TensorFlow `src` directory.
2. Execute `train(CONFIG)` to fit the model while logging metrics.
3. Load the trained weights and call `predict` on Fashion-MNIST samples.
4. Extend with confusion matrices, augmentation strategies, or transfer learning experiments.
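For the confusion-matrix extension, a minimal NumPy sketch follows. It assumes you already have integer true labels and integer predicted class indices (for example, the argmax of per-class scores); whether the repo's `predict` returns indices or scores is an assumption to check in `inference.py`. `confusion_matrix` and `per_class_recall` are hypothetical helpers; `tf.math.confusion_matrix` is a built-in alternative.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, num_classes=10):
    """Count matrix: rows are true classes, columns are predicted classes."""
    y_true = np.asarray(y_true, dtype=int)
    y_pred = np.asarray(y_pred, dtype=int)
    cm = np.zeros((num_classes, num_classes), dtype=int)
    # np.add.at accumulates correctly even when (true, pred) pairs repeat.
    np.add.at(cm, (y_true, y_pred), 1)
    return cm


def per_class_recall(cm):
    """Fraction of each true class predicted correctly (NaN for absent classes)."""
    support = cm.sum(axis=1)
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.diag(cm) / support


# Toy example with three classes.
cm = confusion_matrix([0, 0, 1, 2], [0, 1, 1, 2], num_classes=3)
print(cm)
print(per_class_recall(cm))
```

Off-diagonal cells show which classes get confused with each other; on Fashion-MNIST, visually similar upper-body garments are the natural place to look first.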


**Workflow**

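1. Prepare imports and add the `src` directory to `sys.path`.
2. Train the model with `train(CONFIG)`. TensorFlow uses a visible GPU by default (on Apple silicon this requires the `tensorflow-metal` plugin; "MPS" is PyTorch's name for that backend).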
3. Evaluate predictions from the saved checkpoint.

In [None]:
from pathlib import Path
import sys

NOTEBOOK_DIR = Path().resolve()
SRC_DIR = NOTEBOOK_DIR.parent / 'src'
if str(SRC_DIR) not in sys.path:
    sys.path.append(str(SRC_DIR))

from config import CONFIG  # noqa: E402
from inference import load_model, predict  # noqa: E402
from train import train  # noqa: E402

CONFIG

In [None]:
metrics = train(CONFIG)
metrics

### Interpret the metrics
- Inspect loss and accuracy for both training and validation sets.
- Plot the curves to detect overfitting or optimisation issues.
- Log metrics to TensorBoard by enabling the callback in `train.py`.
- Compare results with the PyTorch notebook to ensure performance parity.
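To turn the curves into a quick diagnostic, the sketch below finds the epoch with the lowest validation loss and flags likely overfitting when validation loss has stopped improving while training loss keeps falling. It assumes `metrics` is a dict shaped like Keras `History.history` (per-epoch lists under `"loss"` and `"val_loss"`); the actual return value of `train` may differ. `summarise_history` and its `patience` threshold are hypothetical.

```python
def summarise_history(history, patience=2):
    """Best epoch by validation loss, plus a simple overfitting flag.

    Assumes per-epoch lists under "loss" and "val_loss" (Keras History.history
    layout). Overfitting is flagged when validation loss has not reached a new
    minimum for `patience` epochs while training loss has kept decreasing.
    """
    train_loss = history["loss"]
    val_loss = history["val_loss"]
    # Index of the lowest validation loss (first one wins on ties).
    best = min(range(len(val_loss)), key=lambda i: val_loss[i])
    epochs_since_best = len(val_loss) - 1 - best
    train_still_falling = train_loss[-1] < train_loss[best]
    return {
        "best_epoch": best + 1,  # 1-based, matching Keras progress logs
        "best_val_loss": val_loss[best],
        "overfitting": epochs_since_best >= patience and train_still_falling,
    }


# Made-up curves: validation loss bottoms out at epoch 3, then rises.
example = {
    "loss": [0.60, 0.40, 0.30, 0.25, 0.20],
    "val_loss": [0.50, 0.38, 0.35, 0.37, 0.40],
}
print(summarise_history(example))
```

If `train` does return that layout, `summarise_history(metrics)` applies the same check to the real run; the best epoch is also a sensible value for early-stopping experiments.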

In [None]:
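import numpy as np
import tensorflow as tf

# Load only the held-out test split; the training split is not needed here.
(_, _), (test_images, test_labels) = tf.keras.datasets.fashion_mnist.load_data()

# Take the first test image and add a trailing channel axis: (28, 28) -> (28, 28, 1).
image = test_images[0][:, :, np.newaxis]
label = int(test_labels[0])

# Load the trained model using the settings in CONFIG (see inference.py).
model = load_model(config=CONFIG)
prediction = predict([image], model=model)[0]

# Compare the ground-truth label with the model's prediction.
label, prediction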
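Raw label indices are hard to read, so the sketch below maps them to the ten Fashion-MNIST class names in dataset index order. It assumes `prediction` is either an integer class index or a sequence of ten per-class scores; `FASHION_MNIST_CLASSES`, `class_name`, and `top_class` are hypothetical helpers, not part of `../src`.

```python
# Fashion-MNIST class names, indexed by label 0-9.
FASHION_MNIST_CLASSES = (
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
)


def top_class(scores):
    """Index of the highest score (an argmax that works on lists or 1-D arrays)."""
    return max(range(len(scores)), key=lambda i: scores[i])


def class_name(index):
    """Human-readable name for an integer class index."""
    return FASHION_MNIST_CLASSES[int(index)]


# Made-up scores; substitute `label` and `prediction` from the cell above.
example_scores = [0.01, 0.02, 0.0, 0.01, 0.0, 0.05, 0.01, 0.10, 0.0, 0.80]
print(class_name(9), class_name(top_class(example_scores)))  # Ankle boot Ankle boot
```

In the notebook, `class_name(label)` together with either `class_name(prediction)` or `class_name(top_class(prediction))` would replace the raw indices, depending on what `predict` returns.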

### Next experiments
- Add data augmentation layers (RandomFlip, RandomRotation) to improve robustness.
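- Transfer-learn from a pretrained backbone such as MobileNetV2 and compare accuracy. Fashion-MNIST images are 28x28 grayscale, so resize them to at least 32x32 and replicate the channel to three before feeding an ImageNet backbone.
- Compute a confusion matrix to diagnose per-class performance.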
- Deploy the model via TensorFlow Serving or TFLite for edge inference practice.
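The augmentation idea can be sketched with Keras preprocessing layers. These layers randomise their inputs only when called with `training=True` (as happens inside `fit`) and pass inputs through at inference. Horizontal flips suit clothing images, whereas vertical flips would not; the rotation factor of 0.05 (a fraction of a full turn, so about ±18 degrees) is an illustrative assumption, and the random batch stands in for real images.

```python
import tensorflow as tf

# Random augmentations; active only when the layers run in training mode.
augment = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.05),  # up to 5% of a full turn either way
])

# Stand-in batch of four 28x28 grayscale images with float pixel values.
images = tf.random.uniform((4, 28, 28, 1), maxval=255.0)
augmented = augment(images, training=True)
print(augmented.shape)
```

In practice the `augment` block would sit at the front of the model, before the first convolution, so that it applies during training only and adds no cost to inference.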