In [42]:
import sys
from pathlib import Path

# Make the repository root importable (so `src.kmnist...` resolves).
# "..": parent of the kernel's current working directory, resolved to an
# absolute path. It is prepended to sys.path only if not already present,
# so re-running this cell does not stack duplicate entries.
REPO_ROOT = Path("..").resolve()
_repo_root_entry = str(REPO_ROOT)

if _repo_root_entry not in sys.path:
    sys.path[:0] = [_repo_root_entry]

print("Using repo root:", REPO_ROOT)

Using repo root: /Users/aaditya/Workspace/Learning/Github/projects/kmnist-recognizer


In [44]:
from pathlib import Path
import torch

from src.kmnist.data import load_kmnist_npz, split_train_val, make_loader
from src.kmnist.models.mlp import MLPWide
from src.kmnist.train import train_model

# Repo-relative paths. REPO_ROOT is resolved to an absolute path so it agrees
# with the sys.path setup cell (which also resolves it). Resolving also keeps
# these paths valid if the working directory changes later in the session.
REPO_ROOT  = Path("..").resolve()
DATA_DIR   = REPO_ROOT / "data"
MODELS_DIR = REPO_ROOT / "models"
MODELS_DIR.mkdir(parents=True, exist_ok=True)

# Seed torch's global RNG so weight initialisation and dropout masks repeat
# across Restart & Run All. Shuffling is also covered, presumably:
# make_loader is assumed to use torch's default generator (TODO confirm).
SEED = 42
torch.manual_seed(SEED)

# Device preference: Apple MPS, then CUDA, then CPU.
device = torch.device("mps" if torch.backends.mps.is_available()
                      else "cuda" if torch.cuda.is_available()
                      else "cpu")
device

device(type='mps')

In [45]:
# Load the KMNIST training split (images + labels) from DATA_DIR.
x, y = load_kmnist_npz(DATA_DIR, split="train")

# Hold out 10% of the samples for validation. seed=42 is passed so the split
# is reproducible (assuming split_train_val uses it for its shuffle; TODO confirm).
tr_x, tr_y, va_x, va_y = split_train_val(x, y, val_frac=0.1, seed=42)

print("Train:", tr_x.shape, tr_y.shape)
print("Val:  ", va_x.shape, va_y.shape)

# Sanity checks:
# - images are 28x28
# - labels are 1-D
# - every image has a matching label
# - train + val together account for every loaded sample (by count)
assert tr_x.shape[1:] == (28, 28)
assert va_x.shape[1:] == (28, 28)
assert tr_y.ndim == 1 and va_y.ndim == 1
assert len(tr_x) == len(tr_y) and len(va_x) == len(va_y), "image/label count mismatch"
assert len(tr_x) + len(va_x) == len(x), "train/val split lost or added samples"

Train: (54000, 28, 28) (54000,)
Val:   (6000, 28, 28) (6000,)


In [46]:
# Mini-batch loaders: shuffled batches of 128 for training, fixed-order
# batches of 256 for validation.
train_loader = make_loader(tr_x, tr_y, shuffle=True, batch_size=128)
val_loader = make_loader(va_x, va_y, shuffle=False, batch_size=256)

# Batches per epoch for each loader.
(len(train_loader), len(val_loader))

(422, 24)

In [47]:
# Build the wide MLP with dropout p=0.35 and move its parameters to `device`.
model = MLPWide(p=0.35)
model = model.to(device)
model  # display the layer stack

MLPWide(
  (net): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=784, out_features=512, bias=True)
    (2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): GELU(approximate='none')
    (4): Dropout(p=0.35, inplace=False)
    (5): Linear(in_features=512, out_features=256, bias=True)
    (6): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): GELU(approximate='none')
    (8): Dropout(p=0.35, inplace=False)
    (9): Linear(in_features=256, out_features=128, bias=True)
    (10): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): GELU(approximate='none')
    (12): Dropout(p=0.35, inplace=False)
    (13): Linear(in_features=128, out_features=10, bias=True)
  )
)

In [48]:
# Train for up to 40 epochs with early stopping. With save_best_only=True,
# the best checkpoint is written to out_path.
# Bind the return value to a real name rather than `_`. Assigning to `_`
# shadows IPython's output-history variable, and IPython then stops
# auto-updating `_`, `__` and `___`. What train_model returns is defined in
# src.kmnist.train (presumably training history; TODO confirm).
train_result = train_model(
    model,
    train_loader,
    epochs=40,                          # max epochs
    lr=1e-3,                            # learning rate
    val_loader=val_loader,
    out_path=MODELS_DIR / "multilayer_perceptron_model.pt",
    early_stopping=True,                # enable early stopping
    patience=5,                         # stop if no val improvement in 5 epochs
    min_delta=1e-3,                     # improvement threshold
    save_best_only=True                 # persist best model only
)

Epoch 01 | train acc 0.853 loss 0.498 | val acc 0.905 loss 0.295
Epoch 02 | train acc 0.919 loss 0.263 | val acc 0.935 loss 0.205
Epoch 03 | train acc 0.938 loss 0.200 | val acc 0.927 loss 0.225
Epoch 04 | train acc 0.948 loss 0.168 | val acc 0.950 loss 0.167
Epoch 05 | train acc 0.954 loss 0.144 | val acc 0.952 loss 0.158
Epoch 06 | train acc 0.961 loss 0.125 | val acc 0.948 loss 0.172
Epoch 07 | train acc 0.964 loss 0.113 | val acc 0.943 loss 0.185
Epoch 08 | train acc 0.967 loss 0.107 | val acc 0.956 loss 0.151
Epoch 09 | train acc 0.969 loss 0.095 | val acc 0.958 loss 0.140
Epoch 10 | train acc 0.973 loss 0.085 | val acc 0.957 loss 0.146
Epoch 11 | train acc 0.975 loss 0.079 | val acc 0.961 loss 0.134
Epoch 12 | train acc 0.976 loss 0.074 | val acc 0.963 loss 0.131
Epoch 13 | train acc 0.978 loss 0.067 | val acc 0.951 loss 0.170
Epoch 14 | train acc 0.980 loss 0.064 | val acc 0.958 loss 0.140
Epoch 15 | train acc 0.981 loss 0.059 | val acc 0.964 loss 0.135
Epoch 16 | train acc 0.98

In [49]:
# Reload the best checkpoint saved during training. With early stopping, the
# in-memory `model` may hold later-epoch weights rather than the best ones.
# The checkpoint is a state dict (it is passed straight to load_state_dict),
# so weights_only=True restricts unpickling to tensors and plain containers.
# A tampered .pt file therefore cannot execute arbitrary code on load.
best = torch.load(
    MODELS_DIR / "multilayer_perceptron_model.pt",
    map_location=device,
    weights_only=True,
)
model.load_state_dict(best)
model.eval()  # inference mode: Dropout disabled, BatchNorm uses running stats

print("✅ Best model reloaded and ready for inference.")

✅ Best model reloaded and ready for inference.
