## Inference: end-to-end pipeline run (train, evaluate, export)

In [None]:
import os
from ml import ml_impl_keras  # ensure registrations are loaded
from ml.ml_pipeline import run

# Pipeline configuration for `run(cfg)`, followed by a single end-to-end run.
# Sections: data -> augmenters -> model -> train -> evaluate -> export.

# Alternative data sources, selectable by name via DATA_SOURCE below.
# This replaces the previously commented-out "data" blocks, so switching
# datasets is a one-line change instead of (un)commenting dict entries.
DATA_PRESETS = {
    "kaggle_potato_disease": {
        "source": "kaggle_potato_disease",
        "params": {
            "img_height": 160,  # Adjust as needed
            "img_width": 160,
            "color_mode": "rgb",
            # NOTE(review): presumably rescales pixels to [0, 1], which would
            # match model.params.from_0_1=True below -- confirm in the loader.
            "normalize": True,
            "subset_fraction": 0.4,
            "max_images_per_class": 50,
            "val_split": 0.15,
            "test_split": 0.15,
            "random_state": 42,
        },
    },
    # NOTE(review): the two presets below produce 28x28 grayscale inputs;
    # a MobileNetV2 transfer model likely expects 3-channel inputs of at least
    # 32x32 -- confirm the "model" section before switching to them.
    "keras_mnist": {
        "source": "keras_mnist",
        "params": {},
    },
    "from_directory": {
        "source": "from_directory",
        "params": {
            "train_dir": "./data/train",
            "val_dir": "./data/val",
            "test_dir": "./data/test",
            "img_height": 28,
            "img_width": 28,
            "color_mode": "grayscale",
            "normalize": True,
        },
    },
}

# Which entry of DATA_PRESETS feeds the pipeline.
DATA_SOURCE = "kaggle_potato_disease"

cfg = {
    "data": DATA_PRESETS[DATA_SOURCE],
    "augmenters": [],
    "model": {
        "builder": "mobilenet_v2_transfer",
        "params": {
            "freeze_base": True,
            "mixed_precision": True,
            "dropout": 0.2,
            "dense_units": 128,
            "optimizer": "Adam",
            "learning_rate": 0.0005,
            "from_0_1": True,
        },
    },
    "train": {
        "trainer": "tfdata",
        "params": {
            "epochs": 1,
            "batch_size": 64,
            "use_validation": True,
            "monitor": "val_loss",
            "patience": 1,
            "checkpoint_path": "best_model.h5",
            "verbose": 1,
            "shuffle_buffer": 2000,
            "cache": True,
            "prefetch": True,
            "xla": True,
            # NOTE(review): with max_images_per_class=50 the training split is
            # far smaller than steps_per_epoch * batch_size (50 * 64) images;
            # confirm the "tfdata" trainer repeats the dataset, otherwise the
            # epoch may stop early with an "ran out of data" warning.
            "steps_per_epoch": 50,
            "validation_steps": 10,
            "augment": {
                "flip_lr": True,
                "flip_ud": False,
                "rotate90": False,
                "brightness": 0.05,
                "contrast": 0.05,
            },
        },
    },
    "evaluate": {"evaluator": "basic", "params": {}},
    "export": [
        # NOTE(review): same path as train.params.checkpoint_path. If run()
        # exports the final (not best) model, this overwrites the best
        # checkpoint -- confirm, or give the two distinct paths.
        {"name": "save_h5", "params": {"path": "best_model.h5"}},
        {"name": "save_saved_model", "params": {"out_dir": "saved_model"}},
        {"name": "save_tfjs", "params": {"out_dir": "tfjs_model"}},
        {"name": "save_tflite", "params": {"path": "tflite/model.tflite"}},
    ],
}

# run() returns a (model, results) pair, per this unpacking.
model, results = run(cfg)
# Last bare expression in the cell: rendered as the notebook cell output.
results