In [None]:
%maven ai.djl:api:0.28.0
%maven ai.djl:basicdataset:0.28.0
%maven ai.djl:model-zoo:0.28.0
%maven ai.djl.mxnet:mxnet-engine:0.28.0
%maven org.slf4j:slf4j-simple:2.0.1

import ai.djl.basicdataset.cv.classification.Mnist;
import ai.djl.basicmodelzoo.basic.*;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.modality.cv.*;
import ai.djl.modality.Classifications;
import ai.djl.Model;
import ai.djl.ndarray.*;
import ai.djl.ndarray.types.*;
import ai.djl.training.*;
import ai.djl.training.loss.*;
import ai.djl.training.listener.*;
import ai.djl.training.evaluator.*;
import ai.djl.training.util.*;
import ai.djl.translate.*;
import ai.djl.util.Pair;
import ai.djl.inference.Predictor;
import java.nio.file.*;
import org.slf4j.*;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.util.stream.*;

    int INPUT_SIZE = 28 * 28; // 784
    int OUTPUT_SIZE = 10;
    int EPOCH_COUNT = 2;
// Directory the training cell saves the model into and the prediction cell loads it from.
String LOCAL_PATH = "build/mlp";

// Set this to true to (re)train the model and save it before running the prediction cells.
boolean needCreateTrainingAndSave = false;

// Notebook-wide variables that the later cells assign.
Model model;
Image image;
Translator<Image, Classifications> translator;
Classifications classifications;

// NOTE(review): the logger is named after java.lang.Class, which looks like a placeholder —
// confirm whether a notebook-specific logger name was intended.
Logger log = LoggerFactory.getLogger(Class.class);

// Best-effort download of a sample picture of MNIST digits; it is only displayed as the
// output of this cell, so a failure is logged and the notebook keeps going.
BufferedImage imageDemo = null;
try {
    URL url = new URL("https://upload.wikimedia.org/wikipedia/commons/2/27/MnistExamples.png");
    imageDemo = ImageIO.read(url);
    if (imageDemo == null) {
        // ImageIO.read returns null (instead of throwing) when no registered reader can decode the data.
        log.warn("Demo image could not be decoded : {}", url);
    }
} catch (Exception e) {
    // The throwable is passed as the trailing argument so the message and stack trace are logged,
    // not just the exception class name.
    log.warn("Brake demo Trainer DataSet with exception : {}", e.getClass().getName(), e);
}
imageDemo

In [None]:
// Create, train and save the model (runs only when needCreateTrainingAndSave is true).
if (needCreateTrainingAndSave) {
    int BATCH_SIZE = 32;
    Mnist basicDataSet = Mnist.builder().setSampling(BATCH_SIZE, true).build();
    try {
        basicDataSet.prepare(new ProgressBar());
        // Logged only after prepare() succeeded, so the message is never printed for a failed download.
        log.info("DataSet initiated");
    } catch (Exception e) {
        log.warn("Brake DataSet initialisation with exception : " + e.getClass().getName());
        // Training cannot proceed without the dataset: fail here, with the cause attached,
        // instead of continuing into training with an unprepared dataset.
        throw new RuntimeException(e);
    }
    // Create Model
    try (Model modelLocal = Model.newInstance("mlp")) {
        modelLocal.setBlock(new Mlp(INPUT_SIZE, OUTPUT_SIZE, new int[] {128, 64}));
        log.info("Model created and to be training");
        // Train Model
        DefaultTrainingConfig trainingConfig = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
              // softmaxCrossEntropyLoss is a standard loss for classification problems
              .addEvaluator(new Accuracy()) // Use accuracy so we humans can understand how accurate the model is
              .addTrainingListeners(TrainingListener.Defaults.logging());
        // Now that we have our training configuration, we create a new trainer for our model.
        // The trainer is AutoCloseable, so try-with-resources releases it even if training fails.
        try (Trainer trainer = modelLocal.newTrainer(trainingConfig)) {
            trainer.initialize(new Shape(1, INPUT_SIZE));
            // Deep learning is typically trained in epochs where each epoch trains the model on each item in the dataset once.
            EasyTrain.fit(trainer, EPOCH_COUNT, basicDataSet, null);
            // Save Model
            Path modelDir = Paths.get(LOCAL_PATH);
            Files.createDirectories(modelDir);
            modelLocal.setProperty("Epoch", String.valueOf(EPOCH_COUNT));
            modelLocal.save(modelDir, "mlp");
            log.info("Model trained. Summary:\n" + modelLocal);
        }
    } catch (Exception e) {
        log.warn("Brake training with exception : " + e.getClass().getName());
        throw new RuntimeException(e);
    }
    log.info("Model saved");
}

In [None]:
try {
    image = ImageFactory.getInstance().fromUrl("https://resources.djl.ai/images/0.png");
    image.getWrappedImage();
    log.info("Predicting Image downloaded");
    imageDemo = ImageIO.read(new URL("https://resources.djl.ai/images/0.png"));
} catch (Exception e) {
    log.warn("Brake demo Trainer DataSet with exception : " + e.getClass().getName());
}
imageDemo

In [None]:
// Load the trained model from LOCAL_PATH and classify the downloaded image.
// Reset first, so a failed run does not display the result left over from an earlier run.
classifications = null;
try (Model model = Model.newInstance("mlp")) {
    if (image == null) {
        // Gives a readable reason in the log instead of an anonymous NullPointerException later on.
        throw new IllegalStateException("No image to predict - run the image download cell first");
    }
    model.setBlock(new Mlp(INPUT_SIZE, OUTPUT_SIZE, new int[] {128, 64}));
    model.load(Paths.get(LOCAL_PATH));
    log.info("Model loaded");
    // Create Translator
    translator = new Translator<>() {
        // Converts the input image to a grayscale NDArray and then to a tensor for the model.
        @Override
        public NDList processInput(TranslatorContext context, Image input) throws Exception {
            NDArray array = input.toNDArray(context.getNDManager(), Image.Flag.GRAYSCALE);
            return new NDList(NDImageUtils.toTensor(array));
        }
        // Turns the single model output into Classifications labelled "0" .. "9".
        @Override
        public Classifications processOutput(TranslatorContext translatorContext, NDList list) throws Exception {
            // Create a Classifications with the output probabilities
            NDArray probabilities = list.singletonOrThrow().softmax(0);
            List<String> classNames = IntStream
                .range(0, 10)
                .mapToObj(String::valueOf)
                .collect(Collectors.toList());
            return new Classifications(classNames, probabilities);
        }
        @Override
        public Batchifier getBatchifier() {
            // The Batchifier describes how to combine a batch together
            // Stacking, the most common batchifier, takes N [X1, X2, ...] arrays to a single [N, X1, X2, ...] array
            return Batchifier.STACK;
        }
    };
    log.info("Translator created");
    // Predict Image
    // The predictor is AutoCloseable, so try-with-resources releases it once the prediction is done.
    try (Predictor<Image, Classifications> predictor = model.newPredictor(translator)) {
        classifications = predictor.predict(image);
    }
    log.info("Image predicted");
} catch (Exception e) {
    // The throwable is passed as the trailing argument so the message and stack trace are logged,
    // not just the exception class name.
    log.warn("Brake Load Model and Predict Image with exception : {}", e.getClass().getName(), e);
}
classifications