# Classification
The task is to identify which category an object belongs to.

In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

### Let's try to create an ML model that recognizes handwritten digits. First, we download the data.

In [None]:
# Load the MNIST train/test splits through the Keras dataset helper.
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Divide the pixel values by 255 so that 8-bit intensities end up in [0, 1].
x_train = x_train / 255.0
x_test = x_test / 255.0

### Now we can see what our task is: from a small greyscale picture, decide which digit is written on it. The correct answer is shown here as `target`. Since we know the correct answers, we can use a supervised learning method.

In [None]:
# Show the first 25 training images in a 5x5 grid, each titled with its label.
plt.figure(figsize=(10, 10))
for idx in range(25):
    ax = plt.subplot(5, 5, idx + 1)
    plt.imshow(x_train[idx], cmap='gray')
    plt.title(f'Target {y_train[idx]}')
    # Hide both axes; pixel coordinates carry no information here.
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)

### We will use a neural network. First we create a simple neural network model, here using the TensorFlow library.

In [None]:
# Small fully connected classifier:
#   Flatten -> Dense(256, ReLU) -> Dropout(0.25) -> Dense(10, softmax)
# The hidden layer needs a non-linear activation. With the default (linear)
# activation the two Dense layers collapse into a single linear map, so the
# model could not learn anything beyond softmax regression.
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dropout(0.25),
    tf.keras.layers.Dense(10, activation='softmax'),
])

# Labels are plain integers (not one-hot), hence the *sparse* cross-entropy.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)


### We now train the created model on the training data and then evaluate it on the test data.

In [None]:
# Train for 20 epochs in mini-batches of 512 samples, reshuffling each epoch.
model.fit(
    x_train,
    y_train,
    batch_size=512,
    epochs=20,
    shuffle=True,
)

# Evaluate the trained model on the test split.
model.evaluate(x_test, y_test, verbose=2)

### Now we can see the resulting classification. The caption above each image we are trying to classify has the format {correct answer} - {prediction}.

In [None]:
# Plot 25 consecutive test images starting at index `start`; each title reads
# "{true label} - {predicted label}".
start = 100

# Classify the whole slice in a single batched call. Calling predict() once
# per image inside the loop is much slower and prints a progress bar each time.
predicted = np.argmax(model.predict(x_test[start:start + 25], verbose=0), axis=1)

plt.figure(figsize=(10, 10))
for idx, label in enumerate(predicted):
    plt.subplot(5, 5, idx + 1)
    fig = plt.imshow(x_test[idx + start], cmap='gray')
    plt.title('{} - {}'.format(y_test[idx + start], label))
    fig.axes.get_xaxis().set_visible(False)
    fig.axes.get_yaxis().set_visible(False)

### In the figure below we can inspect a few incorrect predictions

In [None]:
# Predicted class for every test image, and the indices where it disagrees
# with the true label.
ypredict = np.argmax(model.predict(x_test), axis=1)
wrong = np.where(ypredict != y_test)[0]

# Show up to 25 misclassified images, titled "{true label} - {prediction}".
# Slicing (rather than indexing wrong[0..24]) avoids an IndexError when the
# model makes fewer than 25 mistakes.
plt.figure(figsize=(10, 10))
for idx, sample in enumerate(wrong[:25]):
    plt.subplot(5, 5, idx + 1)
    fig = plt.imshow(x_test[sample], cmap='gray')
    plt.title('{} - {}'.format(y_test[sample], ypredict[sample]))
    fig.axes.get_xaxis().set_visible(False)
    fig.axes.get_yaxis().set_visible(False)


In [None]:
from IPython.display import HTML, display
from textwrap import dedent
import base64, io, numpy as np
from PIL import Image
from google.colab import output
import ipywidgets as W
import tensorflow as tf

# Render the in-notebook drawing pad: a 280x280 canvas with a brush-size
# slider and Eraser / Pen / Clear buttons, drawn in black on a white background.
# The embedded script also publishes two functions on `window` that the Python
# side calls through `output.eval_js`:
#   __mnist_get_png()    -> the current canvas as a PNG data URL
#   __mnist_message(txt) -> sets the status line under the canvas (as plain text)
display(HTML(dedent("""
<style>
#draw-wrap { display: inline-block; user-select: none; }
#draw-toolbar { margin-bottom: 8px; display: flex; gap: 8px; align-items: center; }
#draw-canvas { border: 1px solid #ccc; touch-action: none; background: white; }
#pred-output { font-family: ui-monospace, monospace; margin-top: 8px; }
</style>
<div id="draw-wrap">
  <div id="draw-toolbar">
    <label>Brush size: <input id="brush" type="range" min="3" max="40" step="1" value="18"></label>
    <button id="eraser">Eraser</button>
    <button id="pen" disabled>Pen</button>
    <button id="clear">Clear</button>
  </div>
  <canvas id="draw-canvas" width="280" height="280"></canvas>
  <div id="pred-output">Draw a digit (0–9). Then run the prediction cell.</div>
</div>
<script>
(() => {
  const canvas = document.getElementById('draw-canvas');
  const ctx = canvas.getContext('2d');
  const brush = document.getElementById('brush');
  const btnEraser = document.getElementById('eraser');
  const btnPen = document.getElementById('pen');
  const btnClear = document.getElementById('clear');
  const out = document.getElementById('pred-output');

  // white background
  ctx.fillStyle = 'white';
  ctx.fillRect(0, 0, canvas.width, canvas.height);

  let drawing = false;
  let last = null;
  let erasing = false;

  const getPos = (e) => {
    const r = canvas.getBoundingClientRect();
    if (e.touches && e.touches[0]) {
      return { x: e.touches[0].clientX - r.left, y: e.touches[0].clientY - r.top };
    }
    return { x: e.clientX - r.left, y: e.clientY - r.top };
  };

  const drawLine = (x0, y0, x1, y1, w) => {
    ctx.lineWidth = w;
    ctx.lineCap = 'round';
    ctx.strokeStyle = erasing ? 'white' : 'black';
    ctx.beginPath();
    ctx.moveTo(x0, y0);
    ctx.lineTo(x1, y1);
    ctx.stroke();
  };

  const start = (e) => { drawing = true; last = getPos(e); e.preventDefault(); };
  const move  = (e) => {
    if (!drawing) return;
    const p = getPos(e);
    drawLine(last.x, last.y, p.x, p.y, parseInt(brush.value));
    last = p;
    e.preventDefault();
  };
  const end   = (e) => { drawing = false; last = null; e.preventDefault(); };

  canvas.addEventListener('mousedown', start);
  canvas.addEventListener('mousemove', move);
  canvas.addEventListener('mouseup', end);
  canvas.addEventListener('mouseleave', end);
  canvas.addEventListener('touchstart', start, {passive:false});
  canvas.addEventListener('touchmove', move, {passive:false});
  canvas.addEventListener('touchend', end);

  btnClear.onclick = () => {
    ctx.fillStyle = 'white';
    ctx.fillRect(0,0,canvas.width, canvas.height);
    out.textContent = "Canvas cleared.";
  };
  btnEraser.onclick = () => { erasing = true; btnEraser.disabled = true; btnPen.disabled = false; };
  btnPen.onclick = () => { erasing = false; btnPen.disabled = true; btnEraser.disabled = false; };

  // functions available to Python via eval_js
  window.__mnist_get_png = () => canvas.toDataURL('image/png');
  window.__mnist_message = (txt) => { out.textContent = txt; };
})();
</script>
""")))

# --- Preprocessing settings (read by preprocess_from_canvas) ---
# Flip intensities so the black strokes on the white canvas become bright
# strokes on a dark background.
INVERT = True
# (width, height) the drawing is resized to before it is fed to the model.
SIZE = (28, 28)

# --- UI elements ---
btn = W.Button(description="Predict", button_style="primary")
lbl = W.HTML(value="Ready.")          # status / prediction text next to the button
img_small = W.HTML(value="")          # 1:1 preview of the network input
img_large = W.HTML(value="")          # enlarged preview of the same image
stats = W.HTML(value="")              # summary statistics of the input array

box_preview = W.VBox(children=[
    W.HTML(value="<b>Network input (28×28, grayscale)</b>"),
    W.HBox(children=[img_small, img_large]),
    stats,
])

# Button row first, preview panel underneath.
display(W.HBox(children=[btn, lbl]))
display(box_preview)

def _to_png_bytes(arr_2d_uint8: np.ndarray) -> bytes:
    """Encode a 2-D array as an 8-bit grayscale ('L' mode) PNG.

    Args:
        arr_2d_uint8: Image data; per the parameter name, a 2-D uint8 array.

    Returns:
        The PNG file contents as bytes.
    """
    buffer = io.BytesIO()
    Image.fromarray(arr_2d_uint8, mode='L').save(buffer, format='PNG')
    return buffer.getvalue()

def preprocess_from_canvas() -> tuple[np.ndarray, np.ndarray]:
    """Fetch the drawing from the browser canvas and shrink it to SIZE.

    The canvas is read as a PNG data URL through `__mnist_get_png()`, averaged
    over its RGB channels, optionally inverted (see INVERT) and resized with
    Lanczos resampling.

    Returns:
        A pair ``(a2d, a8)``:
          - a2d: float32 array of shape SIZE with values in [0, 1] (model input)
          - a8:  uint8 array of shape SIZE with values in [0, 255] (previews)
    """
    # The data URL looks like "<header>,<base64 payload>"; keep the payload.
    payload = output.eval_js("__mnist_get_png()").split(",", 1)[1]
    canvas_image = Image.open(io.BytesIO(base64.b64decode(payload)))
    rgb = np.array(canvas_image.convert("RGB"), dtype=np.float32)   # (H, W, 3)

    # Plain channel average as the grayscale value, range [0, 255].
    gray = rgb.mean(axis=2)                                         # (H, W)
    gray = 255.0 - gray if INVERT else gray

    resized = Image.fromarray(gray.astype(np.uint8)).resize(SIZE, Image.LANCZOS)

    a8 = np.asarray(resized, dtype=np.uint8)
    a2d = a8.astype(np.float32) / 255.0
    return a2d, a8

def _normalize_input_shape(shape):
    if isinstance(shape, (list, tuple)) and shape and isinstance(shape[0], (list, tuple)):
        return shape[0]
    return shape

def make_batch(a2d: np.ndarray) -> np.ndarray:
    """Wrap a single 2-D image into a float32 batch of one for `model`.

    The layout follows the model's `input_shape` when it is available:
      - rank 4 ``(None, H, W, C)``: a channel axis is appended; for C != 1 the
        image is repeated C times along it. A non-integer C is treated as 1.
      - anything else (including no known input shape): only the batch axis
        is added, giving ``(1, H, W)``.

    Args:
        a2d: The 2-D image array.

    Returns:
        A float32 array with a leading batch axis of length 1.
    """
    input_shape = _normalize_input_shape(getattr(model, "input_shape", None))

    # Default layout: batch axis only.
    batch = a2d[np.newaxis, ...]

    if input_shape is not None and len(input_shape) == 4:
        last_dim = input_shape[-1]
        channels = last_dim if isinstance(last_dim, int) else 1
        with_channels = a2d[..., np.newaxis]                 # (H, W, 1)
        if channels != 1:
            with_channels = np.repeat(with_channels, channels, axis=-1)
        batch = with_channels[np.newaxis, ...]               # (1, H, W, C)

    return np.asarray(batch, dtype=np.float32)

def on_click(_):
    """Handle a click on the Predict button.

    Reads the canvas, refreshes the preview images and the input statistics,
    runs `model` on the drawing, and shows the predicted digit with its top-3
    scores both in the `lbl` widget and in the status line under the canvas.

    Any exception is reported in the same two places instead of being raised.

    Args:
        _: The widget that triggered the callback (unused).
    """
    # Function-scope imports keep this block self-contained.
    import html
    import json

    def show_on_canvas(text: str) -> None:
        # json.dumps yields a valid, fully escaped JavaScript string literal
        # (quotes, backslashes, newlines, non-ASCII), so arbitrary text cannot
        # break or inject into the evaluated call.
        output.eval_js(f"__mnist_message({json.dumps(text)})")

    try:
        a2d, a8 = preprocess_from_canvas()      # (28,28) float32 and (28,28) uint8
        png = _to_png_bytes(a8)
        b64 = base64.b64encode(png).decode("ascii")

        # pixel-perfect previews
        img_small.value = f'<img src="data:image/png;base64,{b64}" style="width:28px;height:28px;image-rendering: pixelated;border:1px solid #ccc;margin-right:12px;" />'
        img_large.value = f'<img src="data:image/png;base64,{b64}" style="width:280px;height:280px;image-rendering: pixelated;border:1px solid #ccc;" />'

        stats.value = (
            f"<tt>array shape: (28,28), dtype=float32, "
            f"min={a2d.min():.3f}, max={a2d.max():.3f}, mean={a2d.mean():.3f}</tt>"
        )

        x = make_batch(a2d)
        y = model.predict(x, verbose=0)[0].astype(np.float32)
        pred = int(y.argmax())
        top3 = y.argsort()[-3:][::-1]           # three highest scores, best first
        top3_text = ", ".join(f"{i}: {y[i]:.3f}" for i in top3)

        lbl.value = f"Prediction: <b>{pred}</b> | Top-3: {top3_text}"
        # The canvas status line is set through textContent, so it gets the
        # plain-text variant; HTML tags would be displayed literally there.
        show_on_canvas(f"Prediction: {pred} | Top-3: {top3_text}")
    except Exception as e:
        # Escape the exception text: it is inserted into an HTML widget.
        lbl.value = (
            f"<span style='color:#b00'>Prediction error: {html.escape(str(e))}</span>"
        )
        show_on_canvas(f"Prediction error: {e}")

# Register on_click as the click handler of the Predict button.
btn.on_click(on_click)
