## Training, evaluation, and export

In [None]:
import os
from ml import ml_impl_keras  # ensure registrations are loaded
from ml.ml_pipeline import run

# End-to-end pipeline configuration: data loading, MobileNetV2 transfer model,
# tf.data training, evaluation, and model export. Passed as-is to `run` below;
# the meaning of each key is defined by the registered ml components
# (see ml.ml_impl_keras), not by this cell.
cfg = {
    "data": {
        "source": "kaggle_plant_rice",  # adjust if you registered a different loader
        "params": {
            "img_height": 224,
            "img_width": 224,
            "train_split": 0.8,
            "val_split": 0.1,
            "seed": 42,
            "normalize": True,
        },
    },
    "augmenters": [],  # keep empty when using on-the-fly tf.data augmentation
    "model": {
        "builder": "mobilenet_v2_transfer",
        "params": {
            "freeze_base": True,
            "mixed_precision": True,
            "dropout": 0.3,
            "dense_units": 256,
            "dense_l2": 1e-4,
            "classifier_l2": 1e-4,
            "optimizer": "Adam",
            "learning_rate": 3e-4,
            # NOTE(review): presumably tells the builder inputs arrive in [0, 1]
            # (consistent with data.params.normalize=True) — confirm in builder.
            "from_0_1": True,
        },
    },
    "train": {
        "trainer": "tfdata",
        "params": {
            "epochs": 40,
            "batch_size": 32,
            "use_validation": True,
            "monitor": "val_loss",
            "patience": 6,
            "checkpoint_path": "checkpoints/mobilenet_stage1.keras",
            "verbose": 1,
            "shuffle_buffer": 4096,
            "cache": False,
            "cache_val": False,
            "prefetch": True,
            "xla": True,
            "steps_per_epoch": None,
            "validation_steps": None,
            "class_weight": "balanced",
            "reduce_lr_on_plateau": {
                "factor": 0.2,
                "patience": 3,
                "min_lr": 3e-6,
            },
            "augment": {
                "flip_lr": True,
                "flip_ud": False,
                "rotate90": True,
                "brightness": 0.2,
                "contrast": 0.2,
            },
        },
    },
    "evaluate": {
        "evaluator": "basic",
        "params": {},
    },
    "export": [
        {"name": "save_h5", "params": {"path": "exports/mobilenet_stage1.h5"}},
        {"name": "save_saved_model", "params": {"out_dir": "exports/saved_model"}},
        {"name": "save_tfjs", "params": {"out_dir": "exports/tfjs"}},
        {"name": "save_tflite", "params": {"path": "exports/model.tflite", "select_tf_ops": True}},
    ],
}


def _ensure_output_dirs(config):
    """Create the parent directory of every output path referenced in *config*.

    Covers the training checkpoint path and each exporter's ``path`` or
    ``out_dir``. Paths are read from the config itself so they cannot drift
    out of sync. This runs before training so a missing directory cannot make
    the checkpoint or the final exports fail after a long run. Only parent
    directories are created; the ``out_dir`` targets themselves are left for
    the exporters to create.
    """
    paths = [config.get("train", {}).get("params", {}).get("checkpoint_path")]
    for exporter in config.get("export", []):
        params = exporter.get("params", {})
        paths.append(params.get("path") or params.get("out_dir"))
    # dirname("") / dirname("file.ext") == "" -> nothing to create for bare names.
    parents = {os.path.dirname(p) for p in paths if p} - {""}
    for parent in sorted(parents):
        os.makedirs(parent, exist_ok=True)


_ensure_output_dirs(cfg)

model, results = run(cfg)
results