# Handwritten Japanese Character Recognition (KMNIST Demo)

This tool lets you **try out an AI model** that recognizes *handwritten Japanese characters* from the **KMNIST dataset**.  
Don’t worry — no programming knowledge or Japanese background is needed to use it.

---

## 🤖 What kind of AI is this?
This demo uses **Artificial Intelligence (AI)** — more specifically a branch called **Deep Learning**.

- **AI** is the general idea: teaching computers to do tasks that usually need human intelligence.  
- **Machine Learning** is a part of AI: instead of being programmed with rules, the computer *learns* from data.  
- **Deep Learning** is a special kind of Machine Learning: it uses many “layers” of artificial **neural networks** (inspired by the brain) to find complex patterns in data.  

Here we apply **Deep Learning for Image Recognition**:
- The computer looks at your image, scaled down to just 28×28 grayscale pixels.  
- Its layers, trained beforehand on thousands of examples, pick out strokes, curves, and shapes.  
- Finally, it predicts which Japanese character your drawing most likely represents.
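
The three steps above can be sketched in a few lines of plain Python. This is only an illustration under simplifying assumptions: the "image" is 2×2 instead of 28×28, the weights are hand-picked rather than learned, and `relu` stands in for the GELU and batch-normalization layers the demo's model uses. The names `dense`, `relu`, and all the weight values are hypothetical.

```python
def dense(inputs, weights, biases):
    """One fully connected layer: every output is a weighted sum of all inputs plus a bias."""
    return [sum(w * x for w, x in zip(row, inputs)) + b
            for row, b in zip(weights, biases)]


def relu(values):
    """Simple nonlinearity: keep positive values, replace negatives with zero."""
    return [max(0.0, v) for v in values]


# Toy 2x2 "image", flattened to 4 pixel intensities in [0, 1].
# (The demo flattens 28x28 images into 784 values instead.)
pixels = [0.0, 1.0,
          1.0, 0.0]

# Hypothetical hand-picked weights; a real model learns these from training data.
hidden_weights = [[1.0, -1.0, -1.0, 1.0],   # responds to the main diagonal
                  [-1.0, 1.0, 1.0, -1.0]]   # responds to the anti-diagonal
hidden_biases = [0.0, 0.0]
output_weights = [[1.0, 0.0],               # class 0: "main diagonal" shape
                  [0.0, 1.0]]               # class 1: "anti-diagonal" shape
output_biases = [0.0, 0.0]

hidden = relu(dense(pixels, hidden_weights, hidden_biases))   # detected "strokes"
scores = dense(hidden, output_weights, output_biases)         # one score per class
predicted_class = scores.index(max(scores))                   # highest score wins

print("hidden:", hidden, "scores:", scores, "predicted class:", predicted_class)
```

The real model works the same way in principle, only with far more pixels, layers, and weights.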

---

## 🔎 What is KMNIST?
- KMNIST (Kuzushiji-MNIST) is a collection of **70,000 tiny images (28×28 pixels each)**: 60,000 for training and 10,000 for testing.  
- Each image shows **one handwritten Japanese Hiragana character** (Hiragana is the basic phonetic script of Japanese).  
- In this demo, the model is trained to recognize **10 different characters**:  
  `o, ki, su, tsu, na, ha, ma, ya, re, wo`

We also show you:
- The character in **Hiragana script** (example: `き`)  
- An approximate **English pronunciation** (example: `"kee"`)  
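
As a rough sketch of how a predicted class number becomes the text you see, the lookup below reuses the label tables from the demo's code. The helper name `describe` is hypothetical, and the sketch assumes class index 0 to 9 follows the order of the list, as it does in the demo.

```python
# Label tables as used by the demo (romanized name, Hiragana, rough English sound).
KMNIST_CLASSES = ["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"]
KANA = {"o": "お", "ki": "き", "su": "す", "tsu": "つ", "na": "な",
        "ha": "は", "ma": "ま", "ya": "や", "re": "れ", "wo": "を"}
PRON = {"o": "oh", "ki": "kee", "su": "soo", "tsu": "tsoo", "na": "nah",
        "ha": "hah", "ma": "mah", "ya": "yah", "re": "reh", "wo": "oh/wo"}


def describe(class_index):
    """Turn a class index (0-9) into a human-readable label (hypothetical helper)."""
    romaji = KMNIST_CLASSES[class_index]
    return f'{romaji} {KANA[romaji]} (sounds like "{PRON[romaji]}")'


print(describe(1))
```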

---

## ✍️ How can you use this demo?
You can try three ways:

1. **Upload**  
   - Upload any image.  
   - The app converts it to grayscale, resizes it to 28×28 pixels, inverts it if the background is light, and predicts which character it looks like.

2. **Draw**  
   - Use your mouse or trackpad to sketch inside the white box.  
   - Click *Predict from drawing* to see what the AI thinks.

3. **Sample**  
   - Pick one of the 10 characters.  
   - The app picks a random example of that character from the KMNIST test set and predicts it.
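
One detail of the Upload and Draw paths is worth a small sketch: the model expects light strokes on a dark background, so the app flips images that look like dark ink on light paper. The plain-Python version below illustrates only that rule (mean brightness above 0.5 means "light background"). The function name `to_model_polarity` is hypothetical; the demo applies the same check to a PyTorch tensor after resizing the image with Pillow.

```python
def to_model_polarity(image):
    """Return the image with light strokes on a dark background.

    `image` is a list of rows of brightness values in [0, 1].
    If the average brightness is above 0.5 we assume a light background
    and invert every pixel; otherwise the image is returned as a copy.
    """
    pixels = [p for row in image for p in row]
    mean_brightness = sum(pixels) / len(pixels)
    if mean_brightness > 0.5:
        return [[1.0 - p for p in row] for row in image]
    return [row[:] for row in image]


# Toy 3x3 example: a single dark dot on white paper.
paper = [[1.0, 1.0, 1.0],
         [1.0, 0.0, 1.0],
         [1.0, 1.0, 1.0]]
flipped = to_model_polarity(paper)
print(flipped)
```

The threshold is a heuristic: a drawing that fills most of the canvas could be flipped by mistake, which is one reason to check the "Your 28×28 image" output.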

---

## 📊 What do you see after prediction?
- **Top-1 prediction**: The AI’s best guess, with its probability and pronunciation.  
- **Top-3 predictions**: The three most likely characters, best guess included.  
- **Probability chart**: A bar graph of the predicted probability for each of the 10 characters.  
- **Your 28×28 image**: Exactly what the model saw (click to zoom).  
- **Reference grid**: Real dataset examples of the predicted character, for comparison.  
- **Sampled test image**: A random example (only in the *Sample* tab).
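
To make the probabilities less mysterious, here is a minimal sketch of how raw model scores can be turned into probabilities (softmax) and then ranked into a top-3 list. The scores are made up for illustration and do not come from the demo's model; `top_k` is a hypothetical helper name.

```python
import math

KMNIST_CLASSES = ["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"]


def softmax(scores):
    """Convert raw scores into probabilities that are positive and sum to 1."""
    highest = max(scores)                       # subtract the max for numerical stability
    exps = [math.exp(s - highest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probs, labels, k=3):
    """Return the k (label, probability) pairs with the highest probability."""
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Made-up scores, one per character, purely for illustration.
scores = [0.1, 3.0, 0.2, 1.5, -1.0, 0.0, 0.3, 2.0, -0.5, 0.4]
probs = softmax(scores)
top3 = top_k(probs, KMNIST_CLASSES)

for label, p in top3:
    print(f"{label}: {p:.3f}")
```

The first entry of `top3` is the Top-1 prediction, and the full `probs` list is what the bar chart displays.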

---

## 🎯 Why is this interesting?
This demo is a small but powerful example of **Deep Learning in action**:
- It shows how AI can **read handwriting**, even from tiny low-resolution images.  
- The same technology (deep learning image recognition) powers many real-world tools:  
  - Phone face unlock  
  - Self-driving cars (recognizing roads, signs, pedestrians)  
  - Medical imaging (detecting tumors in X-rays or MRI scans)  
  - Translating handwritten text into digital form  

Even if you don’t know Japanese, this demo is a fun way to see **how modern AI learns and makes predictions**.

---

## ⚠️ A Note on Accuracy
This demo shows how **AI (Deep Learning)** can recognize handwritten characters, but it’s not always perfect.  

Why mistakes can happen:
- **Training data**: The model has only seen a limited number of handwritten examples (28×28 pixel images). If your handwriting is very different, it may confuse the model.  
- **Small size**: The images are tiny (just 28×28 dots), so fine details are lost.  
- **Room for error**: Like humans, AI can misread unclear or unusual handwriting.  

The important part:  
- This is **normal in AI systems** — even the best models have some error rate.  
- With **more training data** and better models, accuracy generally improves.  
- Real-world AI (like in self-driving cars or medical imaging) is trained on **much larger and richer datasets**, which generally makes it more reliable.  
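
For readers curious how an "error rate" is measured, the sketch below compares a list of true labels with a list of predictions and counts the most frequent mix-ups. The labels are invented for illustration and are not results from this demo; `accuracy` and `common_confusions` are hypothetical helper names.

```python
from collections import Counter


def accuracy(true_labels, predicted_labels):
    """Fraction of predictions that match the true label."""
    correct = sum(t == p for t, p in zip(true_labels, predicted_labels))
    return correct / len(true_labels)


def common_confusions(true_labels, predicted_labels, n=3):
    """Most frequent (true, predicted) pairs among the mistakes."""
    mistakes = Counter((t, p) for t, p in zip(true_labels, predicted_labels) if t != p)
    return mistakes.most_common(n)


# Invented example data: 10 characters and what a model might have guessed.
true_labels = ["ki", "ki", "su", "tsu", "na", "ha", "ki", "re", "wo", "o"]
predictions = ["ki", "ma", "su", "tsu", "na", "ha", "ma", "re", "o",  "o"]

acc = accuracy(true_labels, predictions)
print(f"accuracy: {acc:.0%}, error rate: {1 - acc:.0%}")
print("most common mix-ups:", common_confusions(true_labels, predictions))
```

Looking at which pairs get confused is often more informative than the single accuracy number, because it shows which shapes the model finds similar.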

Think of this demo as a **fun, educational showcase**: it gives you a glimpse of how AI learns and predicts, but don’t expect it to always be 100% correct.

In [7]:
# ==== KMNIST Inference + Auto-Discovered Models (native 28×28, click to zoom) ====

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
import gradio as gr

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms

# ------------------------- Paths & discovery -------------------------
def find_dir(dirname: str) -> Path:
    """
    Find a directory named `dirname` in:
      ./dirname, ../dirname, ../../dirname
    Useful because this notebook lives under 'KMNIST/source code/'.
    """
    here = Path.cwd()
    for base in [here, here.parent, here.parent.parent]:
        cand = base / dirname
        if cand.exists() and cand.is_dir():
            return cand
    raise FileNotFoundError(f"Could not find '{dirname}' near {here}")

DATA_DIR   = find_dir("data")
MODELS_DIR = find_dir("models")

def friendly_name_from_filename(p: Path) -> str:
    """Turn 'mlp_swa.pt' -> 'Mlp Swa' (nice for UI; str.title() capitalizes only the first letter of each word)."""
    return p.stem.replace("_", " ").title()

def detect_arch_from_filename(p: Path) -> str:
    """
    Return a key indicating which model architecture to build for this weight file.
    Extend this function if you add other architectures (cnn, resnet, etc).
    """
    name = p.stem.lower()
    if any(k in name for k in ["mlp", "perceptron"]):
        return "mlp"
    # Example future extension:
    # if "cnn" in name: return "cnn"
    # if "resnet" in name: return "resnet18"
    return "mlp"  # default

def discover_models(models_dir: Path) -> dict:
    """
    Scan models_dir for *.pt files.
    Returns: {ui_name: {"path": str, "arch": "mlp" | ...}}
    """
    out = {}
    for f in sorted(models_dir.glob("*.pt")):
        out[friendly_name_from_filename(f)] = {
            "path": str(f),
            "arch": detect_arch_from_filename(f),
        }
    if not out:
        raise FileNotFoundError(f"No .pt files found in {models_dir}")
    return out

ALL_MODELS = discover_models(MODELS_DIR)

# ------------------------- Device -------------------------
def pick_device():
    if torch.backends.mps.is_available(): return torch.device("mps")
    if torch.cuda.is_available():         return torch.device("cuda")
    return torch.device("cpu")

device = pick_device()
print("device:", device)

# ------------------------- Architectures -------------------------
class MLP_Wide(nn.Module):
    def __init__(self, p=0.35):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(784, 512), nn.BatchNorm1d(512), nn.GELU(), nn.Dropout(p),
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.GELU(), nn.Dropout(p),
            nn.Linear(256, 128), nn.BatchNorm1d(128), nn.GELU(), nn.Dropout(p),
            nn.Linear(128, 10)
        )
    def forward(self, x): return self.net(x)

def build_model(arch_key: str) -> nn.Module:
    """Create a model instance for a given arch key."""
    if arch_key == "mlp":
        return MLP_Wide(0.35)
    # Example future extensions:
    # if arch_key == "cnn": return YourCNN()
    # if arch_key == "resnet18": return torchvision.models.resnet18(num_classes=10, ...)
    raise ValueError(f"Unknown architecture key: {arch_key}")

# We will (re)build & load per selection to allow different architectures in the future
_current = {"model": None, "arch": None, "name": None}

def load_selected_model(model_name: str) -> nn.Module:
    """Build (if needed) and load weights for the chosen model; return a ready model on `device`."""
    info = ALL_MODELS[model_name]
    arch = info["arch"]
    path = info["path"]

    if _current["model"] is None or _current["arch"] != arch or _current["name"] != model_name:
        m = build_model(arch).to(device)
        state = torch.load(path, map_location=device)
        m.load_state_dict(state)
        m.eval()
        _current.update({"model": m, "arch": arch, "name": model_name})
    return _current["model"]

# ------------------------- Labels -------------------------
KMNIST_CLASSES = ["o","ki","su","tsu","na","ha","ma","ya","re","wo"]
KANA = {"o":"お","ki":"き","su":"す","tsu":"つ","na":"な","ha":"は","ma":"ま","ya":"や","re":"れ","wo":"を"}
PRON = {"o":"“oh”","ki":"“kee”","su":"“soo”","tsu":"“tsoo”","na":"“nah”","ha":"“hah”","ma":"“mah”","ya":"“yah”","re":"“reh”","wo":"“oh/wo”"}

# ------------------------- Test set (for Sample tab) -------------------------
_KX = _KY = None
def load_kmnist_test():
    global _KX, _KY
    if _KX is not None:
        return _KX, _KY
    imgs = DATA_DIR / "kmnist-test-imgs.npz"
    labs = DATA_DIR / "kmnist-test-labels.npz"
    if not imgs.exists() or not labs.exists():
        raise FileNotFoundError(f"KMNIST test .npz not found in {DATA_DIR}")
    _KX = np.load(imgs)["arr_0"]
    _KY = np.load(labs)["arr_0"]
    return _KX, _KY

# ------------------------- Helpers -------------------------
to_tensor = transforms.ToTensor()

def preprocess_pil(pil: Image.Image):
    """PIL -> ([1,1,28,28] tensor on device, native 28×28 PIL)."""
    pil = pil.convert("L").resize((28, 28), Image.BILINEAR)
    x = to_tensor(pil)[0]
    if x.mean().item() > 0.5:  # invert if background light
        x = 1.0 - x
    native_28 = Image.fromarray((x.cpu().numpy() * 255).astype(np.uint8))
    return x.unsqueeze(0).unsqueeze(0).to(device), native_28

def probs_to_fig(probs, classes):
    idx = np.argsort(-probs)[:10]
    fig, ax = plt.subplots(figsize=(5, 2.8))
    ax.bar(range(len(idx)), probs[idx])
    ax.set_xticks(range(len(idx)))
    ax.set_xticklabels([classes[i] for i in idx])
    ax.set_ylim(0, 1)
    ax.set_ylabel("prob.")
    ax.grid(axis="y", alpha=0.25)
    fig.tight_layout()
    return fig

def small_ref_grid(label_idx: int, grid=2):
    """grid×grid tiles of test examples at native resolution (56×56 when grid=2)."""
    # Blank fallback sized to match the requested grid (returned if data is missing).
    blank = Image.fromarray(np.zeros((28*grid, 28*grid), np.uint8))
    try:
        x_test, y_test = load_kmnist_test()
    except Exception:
        return blank
    idxs = np.where(y_test == label_idx)[0]
    if len(idxs) == 0:
        return blank
    rng = np.random.default_rng(12345 + label_idx)
    picks = rng.choice(idxs, size=min(grid*grid, len(idxs)), replace=False)
    canvas = Image.new("L", (28*grid, 28*grid), 0)
    r = c = 0
    for idx in picks:
        img = 255 - x_test[idx].astype(np.uint8)  # inverted for display: raw KMNIST is light-on-dark, so this shows dark strokes on light
        canvas.paste(Image.fromarray(img), (c*28, r*28))
        c += 1
        if c == grid:
            c = 0; r += 1
            if r == grid: break
    return canvas

# ------------------------- Predictors -------------------------
@torch.no_grad()
def predict_from_pil(pil_img: Image.Image, which_model: str):
    if pil_img is None:
        return "—", {}, None, None, None
    model = load_selected_model(which_model)
    x, native_28 = preprocess_pil(pil_img)
    probs = F.softmax(model(x), dim=1)[0].cpu().numpy()
    top1 = int(np.argmax(probs))
    romaji = KMNIST_CLASSES[top1]
    md = f"**Top-1:** `{romaji}` {KANA[romaji]} *(p={probs[top1]:.3f}; {PRON[romaji]})*"
    top3 = {KMNIST_CLASSES[i]: float(probs[i]) for i in np.argsort(-probs)[:3]}
    fig  = probs_to_fig(probs, KMNIST_CLASSES)
    ref  = small_ref_grid(top1, grid=2)
    return md, top3, fig, native_28, ref

@torch.no_grad()
def predict_from_sketch(data, which_model: str):
    if data is None:
        return "—", {}, None, None, None

    # Accept multiple payload variants from gr.Sketchpad
    arr = None
    if isinstance(data, dict):
        if data.get("image") is not None:
            arr = data["image"]
        elif data.get("composite") is not None:
            arr = data["composite"]
        elif data.get("layers") is not None:
            imgs = []
            for layer in data["layers"]:
                li = layer.get("image") if isinstance(layer, dict) else layer
                if li is not None:
                    imgs.append(np.asarray(li))
            if imgs:
                arr = np.maximum.reduce(imgs)
    else:
        arr = np.asarray(data)

    if arr is None:
        return "—", {}, None, None, None

    arr = np.asarray(arr)
    if arr.ndim == 3:
        arr = arr.mean(axis=2)
    arr = arr.astype(np.float32)
    if arr.max() <= 1.0:
        arr = (arr * 255).astype(np.uint8)
    else:
        arr = arr.astype(np.uint8)

    pil = Image.fromarray(arr).convert("L")
    return predict_from_pil(pil, which_model)

@torch.no_grad()
def predict_from_label(label_str: str, which_model: str):
    try:
        x_test, y_test = load_kmnist_test()
    except Exception as e:
        return f"— (test set not found: {e})", {}, None, None, None, None
    if label_str not in KMNIST_CLASSES:
        return "—", {}, None, None, None, None

    label_idx = KMNIST_CLASSES.index(label_str)
    idxs = np.where(y_test == label_idx)[0]
    if len(idxs) == 0:
        return "—", {}, None, None, None, None

    pil28 = Image.fromarray(x_test[np.random.choice(idxs)]).convert("L")
    pil28_inv = Image.fromarray(255 - np.array(pil28))  # native 28×28, inverted for display (dark strokes on light)
    md, top3, fig, native_28, ref = predict_from_pil(pil28, which_model)
    return md, top3, fig, native_28, ref, pil28_inv

# ------------------------- UI -------------------------
with gr.Blocks() as demo:
    gr.Markdown("## KMNIST — Handwritten Kana (ひらがな) Recognizer")
    gr.Markdown(
        "Upload, draw, or sample a character. Images are shown at **true 28×28**; "
        "use the ⤢ button on any image to zoom. Choose a model below — "
        "the list is auto-discovered from your `models/` folder."
    )

    with gr.Row():
        with gr.Column(scale=1):
            model_choice = gr.Dropdown(
                choices=list(ALL_MODELS.keys()),
                value=list(ALL_MODELS.keys())[0],
                label="Select Model",
                multiselect=False   # ensures single select
            )

            with gr.Tab("Upload"):
                up = gr.Image(type="pil", image_mode="L", label="Upload")
                btn_up = gr.Button("Predict from upload")
            with gr.Tab("Draw"):
                pad = gr.Sketchpad(label="Draw")
                btn_pad = gr.Button("Predict from drawing")
            with gr.Tab("Sample"):
                sample_label = gr.Dropdown(choices=KMNIST_CLASSES, value="ki", label="Pick class")
                btn_sample = gr.Button("Sample & predict")

        with gr.Column(scale=1):
            top1_out  = gr.Markdown("Top-1: —")
            top3_out  = gr.Label(num_top_classes=3, label="Top-3")
            chart_out = gr.Plot(label="Probabilities")
            your28_out = gr.Image(type="pil", label="Your 28×28 (native)",
                                  show_fullscreen_button=True, interactive=False)
            ref_out    = gr.Image(type="pil", label="Reference (2×2, 56×56 native)",
                                  show_fullscreen_button=True, interactive=False)
            sample_out = gr.Image(type="pil", label="Sampled test image (28×28 native)",
                                  show_fullscreen_button=True, interactive=False)

    btn_up.click(predict_from_pil,     inputs=[up,  model_choice],
                 outputs=[top1_out, top3_out, chart_out, your28_out, ref_out])
    btn_pad.click(predict_from_sketch, inputs=[pad, model_choice],
                 outputs=[top1_out, top3_out, chart_out, your28_out, ref_out])
    btn_sample.click(predict_from_label, inputs=[sample_label, model_choice],
                 outputs=[top1_out, top3_out, chart_out, your28_out, ref_out, sample_out])

demo.launch()


device: mps
* Running on local URL:  http://127.0.0.1:7894
* To create a public link, set `share=True` in `launch()`.
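
To give a sense of scale for the `MLP_Wide` network defined in the cell above, the following back-of-the-envelope sketch counts its learnable parameters from the layer sizes alone. It assumes the standard parameter counts for these layer types (weights plus biases for a linear layer, one scale and one shift per feature for batch normalization) and does not load the model; the helper names are hypothetical.

```python
# Layer widths of MLP_Wide: 784 inputs (28*28 pixels) down to 10 output classes.
LAYER_SIZES = [784, 512, 256, 128, 10]


def linear_params(n_in, n_out):
    """A fully connected layer has one weight per input/output pair plus one bias per output."""
    return n_in * n_out + n_out


def batchnorm_params(n_features):
    """Batch normalization learns one scale and one shift per feature."""
    return 2 * n_features


total = 0
pairs = list(zip(LAYER_SIZES, LAYER_SIZES[1:]))
for position, (n_in, n_out) in enumerate(pairs):
    count = linear_params(n_in, n_out)
    is_output_layer = position == len(pairs) - 1
    if not is_output_layer:
        count += batchnorm_params(n_out)   # hidden layers are followed by BatchNorm1d
    total += count
    print(f"{n_in:>4} -> {n_out:<4} {count:>8,} parameters")

print(f"total learnable parameters: {total:,}")
```

Most of the parameters sit in the first layer, which connects every one of the 784 pixels to 512 hidden units.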


