In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import os


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#config 


def fetchDevice(deviceName = "Default", logging = False) -> torch.device:
    device:torch.device = None
    if deviceName == "Default":
        if (torch.backends.mps.is_available()):
            device = torch.device('mps')
            if logging:
                print("Metal Performance Shaders Available! ")
        elif(torch.cuda.is_available()):
            device = torch.device('cuda')
            if logging: 
                print("NVIDIA CUDA Available! ")
        else:
            device = torch.device('cpu')
            if logging:
                print("ONLY CPU Available! ")

    else:
        device = torch.device(deviceName)

    return device
# get gpu 
DEVICE = fetchDevice()
print("Device:", DEVICE)

# Skip frame for inferencing(saves computing power)
FRAME_SKIP = 5

# "cnn" or "densenet"
MODEL_TYPE = "densenet"   

# train/webcam
MODE = "webcam"

Device: mps


In [3]:
def get_cnn():
    class SimpleCNN(nn.Module):
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(
                nn.Conv2d(1, 32, 3, 1), nn.ReLU(),
                nn.Conv2d(32, 64, 3, 1), nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Flatten(),
                nn.Linear(9216, 128), nn.ReLU(),
                nn.Linear(128, 10)
            )
        def forward(self, x):
            return self.net(x)
    return SimpleCNN()

def get_densenet():
    model = models.densenet121(pretrained=False, num_classes=10)
    model.features.conv0 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model.features.pool0 = nn.Identity()
    return model

MODEL_PATHS = {"cnn": "cnn_mnist.pth", "densenet": "densenet_mnist.pth"}
MODELS = {"cnn": get_cnn, "densenet": get_densenet}

In [4]:
# Only needed if neither model weight exists
def train_and_save(model_type):
    epochs = 5 if model_type == "cnn" else 3

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_set = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=2)

    model = MODELS[model_type]().to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(epochs):  # Few epochs for quick demo
        model.train()
        total, correct = 0, 0
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            output = model(imgs)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            preds = output.argmax(1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
        print(f"[{model_type}] Epoch {epoch+1} train acc: {correct/total:.4f}")
    torch.save(model.state_dict(), MODEL_PATHS[model_type])
    print(f"{model_type} model saved to {MODEL_PATHS[model_type]}")

# Only train if file does not exist
for mt in ["cnn", "densenet"]:
    if not os.path.exists(MODEL_PATHS[mt]):
        train_and_save(mt)
    else:
        print(f"{MODEL_PATHS[mt]} already exists. Skipping training.")


cnn_mnist.pth already exists. Skipping training.
densenet_mnist.pth already exists. Skipping training.


In [5]:
# Select model based on MODEL_TYPE, always loads weights (never trains)
model = MODELS[MODEL_TYPE]().to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATHS[MODEL_TYPE], map_location=DEVICE))
model.eval()
print(f"{MODEL_TYPE} model loaded from {MODEL_PATHS[MODEL_TYPE]}")




densenet model loaded from densenet_mnist.pth


In [6]:
def preprocess_digit(roi):
    im = cv2.resize(roi, (28,28))
    im = im.astype(np.float32) / 255.0
    im = (im - 0.1307) / 0.3081
    im = np.expand_dims(im, axis=0)  
    return torch.from_numpy(im).unsqueeze(0).to(DEVICE)  # 1,1,28,28

def find_digits(
    frame_gray,
    thresh_block_size=13,      # Must be odd, 11, 13, 15, 17...
    thresh_C=15,                # Subtract from mean; higher = fewer detections (darker)
    morph_ksize=5,             # Morph closing kernel size (helps connect strokes, fill gaps)
    min_w=0, max_w=80,        # Min/max width of digit boxes in pixels
    min_h=20, max_h=100,        # Min/max height of digit boxes in pixels
    min_aspect=0.0, max_aspect=5.2,  # Width/height ratio: thin digits like "1" vs round like "0"
    min_area=10                # Minimum contour area to be considered a digit
):


    thresh = cv2.adaptiveThreshold(
        frame_gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY_INV,
        thresh_block_size, thresh_C
    )

    if morph_ksize > 0:
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (morph_ksize, morph_ksize))
        thresh = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel, iterations=1)

    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    digit_bboxes = []
    for c in contours:
        x, y, w, h = cv2.boundingRect(c)
        area = cv2.contourArea(c)
        aspect = w / float(h)
        if (min_w <= w <= max_w and
            min_h <= h <= max_h and
            min_aspect <= aspect <= max_aspect and
            area >= min_area):
            digit_bboxes.append((x, y, w, h))
    digit_bboxes = sorted(digit_bboxes, key=lambda b: b[0])
    return digit_bboxes, thresh



In [9]:
if MODE == "webcam":
    model.eval()
    cap = cv2.VideoCapture(0)
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
    frame_count = 0
    last_digits = []

    print("Press Q to quit.")
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        frame_disp = frame.copy()

        if frame_count % FRAME_SKIP == 0:
            digit_bboxes, thresh = find_digits(frame_gray)
            digits = []
            for bbox in digit_bboxes:
                x, y, w, h = bbox
                roi = thresh[y:y+h, x:x+w]
                roi_pad = cv2.copyMakeBorder(roi, 10, 10, 10, 10, cv2.BORDER_CONSTANT, value=0)
                im_tensor = preprocess_digit(roi_pad)
                with torch.no_grad():
                    pred = model(im_tensor)
                    digit = pred.argmax(1).item()
                digits.append((digit, x, y, w, h))
            last_digits = digits  

        for digit, x, y, w, h in last_digits:
            cv2.rectangle(frame_disp, (x, y), (x + w, y + h), (0, 255, 0), 2)
            cv2.putText(frame_disp, str(digit), (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 0, 255), 2)

        frame_count += 1
        cv2.imshow('Digit Recognition', frame_disp)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()



Press Q to quit.


KeyboardInterrupt: 