# MNIST Digit Classifier

In [1]:
# This cell assumes a project structure of: project-root/src/experiments/this_notebook.ipynb
# It puts project-root/src on the module search path so modules under src can be imported,
# and defines `path`, a string pointing to the project root.

import sys
from pathlib import Path

notebook_dir = Path().resolve()      # current working directory, assumed to be src/experiments
src_dir = str(notebook_dir.parent)   # project-root/src

# Prepend (rather than append) so the project's own packages (`datasets`, `models`, `trainer`)
# win over same-named installed packages such as Hugging Face `datasets`. The absolute path keeps
# imports working if the working directory later changes, and the membership check keeps
# re-running this cell from stacking duplicate entries.
if src_dir not in sys.path:
    sys.path.insert(0, src_dir)

path = str(notebook_dir.parent.parent)  # project root

print(path)


/workspaces/Kaggle-Knowledge-Competitions


Experiment configuration, collected in a single dataclass (consider Hydra for heavier workloads in the future):

In [2]:
from dataclasses import dataclass


@dataclass
class Configurations:
    """Settings for the digit-classifier run.

    Both directory defaults are built from ``path`` (the project root computed in the
    first cell), so this cell must be executed after that one.
    """

    # Data loading: dataset root for KaggleMNIST plus DataLoader options.
    data_dir: str = f"{path}/data/kaggle_mnist"
    batch_size: int = 64
    num_workers: int = 4
    pin_memory: bool = True

    # Learning rate and momentum, forwarded to train_digit_classifier.
    lr: float = 0.01
    momentum: float = 0.9

    # Number of passes over the training split.
    num_epochs: int = 5

    # TensorBoard output: each run writes to a timestamped subdirectory of log_dir.
    log_dir: str = f"{path}/logs/flax-digit-classifier"
    log_every_n_steps: int = 50


cfg = Configurations()

In [3]:
import os
import socket
from datetime import datetime

import jax
import torch
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard.writer import SummaryWriter

from trainer.digit_classifier_trainer import train_digit_classifier
from datasets.kaggle_mnist import KaggleMNIST
from models.digit_classifier import ResNet18

SEED = 0            # seeds the train/val split, the training-loader shuffle and the JAX PRNG key
VAL_FRACTION = 0.2  # fraction of the labelled training data held out for validation

# One TensorBoard run directory per launch: <cfg.log_dir>/<MonDD_HH-MM-SS>_<hostname>.
current_time = datetime.now().strftime('%b%d_%H-%M-%S')
log_dir = os.path.join(cfg.log_dir, current_time + '_' + socket.gethostname())

model = ResNet18(num_classes=10)
data = KaggleMNIST(data_dir=cfg.data_dir, train=True, transform=None)

# Hold out a fixed fraction for validation. The seeded generator makes the split identical
# across kernel restarts, so validation metrics from different runs are comparable.
n_val = int(len(data) * VAL_FRACTION)
n_train = len(data) - n_val
train_data, val_data = random_split(
    data, [n_train, n_val], generator=torch.Generator().manual_seed(SEED)
)

train_loader = DataLoader(
    train_data,
    shuffle=True,
    batch_size=cfg.batch_size,
    num_workers=cfg.num_workers,
    pin_memory=cfg.pin_memory,
    generator=torch.Generator().manual_seed(SEED),  # reproducible shuffle order
)

val_loader = DataLoader(
    val_data,
    shuffle=False,
    batch_size=cfg.batch_size,
    num_workers=cfg.num_workers,
    pin_memory=cfg.pin_memory,
)

logger = SummaryWriter(log_dir=log_dir)
try:
    train_digit_classifier(
        model,
        train_loader,
        val_loader,
        jax.random.PRNGKey(SEED),
        cfg.lr,
        cfg.momentum,
        num_epochs=cfg.num_epochs,
        logger=logger,
        log_every_n_steps=cfg.log_every_n_steps,
    )
finally:
    # Flush pending events and release the event file even if training raises or is
    # interrupted; closing an already-closed SummaryWriter is a no-op.
    logger.close()