In [None]:
import cv2
import numpy as np
from scipy import ndimage
import matplotlib.pyplot as plt
from tensorflow import keras

In [None]:
def _sharpness_score(img_gray):
    """Focus measure: variance of the Laplacian response of ``img_gray``.

    The Laplacian is taken with a float64 output depth (``cv2.CV_64F``) so
    negative edge responses are kept rather than clipped. A larger variance
    means more high-frequency detail, which the capture loop uses to pick the
    sharpest of several frames.
    """
    edge_response = cv2.Laplacian(img_gray, cv2.CV_64F)
    score = edge_response.var()
    return score

def _deskew(binary255):

    m = cv2.moments(binary255)
    if abs(m["mu02"]) < 1e-2: 
        return binary255

    angle = 0.5 * np.arctan2(2*m["mu11"], (m["mu20"] - m["mu02"]))
    angle_deg = angle * 180.0 / np.pi

    (h, w) = binary255.shape
    M = cv2.getRotationMatrix2D((w//2, h//2), angle_deg, 1.0)
    rotated = cv2.warpAffine(binary255, M, (w, h), flags=cv2.INTER_NEAREST, borderValue=0)
    return rotated

def _blank_mnist_result(roi_bgr, bin_inv):
    """Return an all-zero model input plus debug info; used when no ink is found."""
    canvas = np.zeros((28, 28), dtype=np.uint8)
    model_input = canvas.astype(np.float32)[None, ..., None]
    return model_input, {"roi": roi_bgr, "binary": bin_inv, "final28": canvas}


def _binarize_roi(roi_bgr):
    """Turn a BGR crop into a cleaned binary mask: ink = 255, background = 0."""
    gray = cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2GRAY)
    # CLAHE equalises contrast locally (8x8 tiles) before thresholding.
    gray = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(gray)

    # Inverse threshold: dark ink on light paper becomes white-on-black.
    bin_inv = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                    cv2.THRESH_BINARY_INV, 31, 10)
    # Almost no foreground found (mean below 5 of 255): fall back to global Otsu.
    if np.mean(bin_inv) < 5:
        _, bin_inv = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    bin_inv = cv2.morphologyEx(bin_inv, cv2.MORPH_OPEN, kernel, iterations=1)  # remove specks
    bin_inv = cv2.dilate(bin_inv, kernel, iterations=2)  # thicken / join thin strokes
    return bin_inv


def _largest_component_crop(bin_inv):
    """Crop ``bin_inv`` to the bounding box of its largest 8-connected blob (None if none)."""
    num_labels, _, stats, _ = cv2.connectedComponentsWithStats(bin_inv, connectivity=8)
    if num_labels <= 1:  # only label 0, the background
        return None
    idx = 1 + int(np.argmax(stats[1:, cv2.CC_STAT_AREA]))
    x, y = stats[idx, cv2.CC_STAT_LEFT], stats[idx, cv2.CC_STAT_TOP]
    w, h = stats[idx, cv2.CC_STAT_WIDTH], stats[idx, cv2.CC_STAT_HEIGHT]
    # The crop keeps any other ink that happens to fall inside this box.
    return bin_inv[y:y + h, x:x + w]


def _fit_to_mnist_canvas(digit):
    """Scale ``digit`` into a 20x20 box, paste it on 28x28, centre its mass at (14, 14)."""
    h, w = digit.shape
    if h > w:
        new_h, new_w = 20, max(1, int(round(w * (20.0 / h))))
    else:
        new_w, new_h = 20, max(1, int(round(h * (20.0 / w))))
    # MNIST digits were size-normalised with anti-aliasing (grey edges).
    # INTER_AREA averages pixels when shrinking. INTER_NEAREST would just drop
    # rows and columns, which can break thin strokes.
    interp = cv2.INTER_AREA if max(h, w) > 20 else cv2.INTER_LINEAR
    resized = cv2.resize(digit, (new_w, new_h), interpolation=interp)

    canvas = np.zeros((28, 28), dtype=np.uint8)
    y_off = (28 - new_h) // 2
    x_off = (28 - new_w) // 2
    canvas[y_off:y_off + new_h, x_off:x_off + new_w] = resized

    cy, cx = ndimage.center_of_mass(canvas)
    if np.isnan(cx) or np.isnan(cy):  # all-zero canvas has no centre of mass
        cx, cy = 14, 14
    shift = np.float32([[1, 0, int(round(14 - cx))], [0, 1, int(round(14 - cy))]])
    return cv2.warpAffine(canvas, shift, (28, 28), flags=cv2.INTER_NEAREST, borderValue=0)


def preprocess_to_mnist(roi_bgr):
    """Convert a BGR crop containing one handwritten digit into an MNIST-style input.

    The steps are:
      1. Grayscale + CLAHE, then inverse adaptive threshold (Otsu fallback).
      2. Morphological open and dilate to clean up the mask.
      3. Keep the bounding box of the largest connected blob.
      4. Deskew, then trim to the remaining ink.
      5. Scale the longest side to 20 px and paste into a 28x28 canvas.
      6. Shift the centre of mass to (14, 14) and scale to [0, 1].

    Args:
        roi_bgr: BGR image (as delivered by OpenCV capture) containing the digit.

    Returns:
        ``(model_input, debug)``.
        - ``model_input`` is float32 with shape ``(1, 28, 28, 1)``, white ink
          on black, values in [0, 1]. It is all zeros when no ink is found.
        - ``debug`` has keys ``"roi"`` (the input), ``"binary"`` (the cleaned
          mask, ink = 255) and ``"final28"`` (the uint8 28x28 canvas).
    """
    bin_inv = _binarize_roi(roi_bgr)

    digit = _largest_component_crop(bin_inv)
    if digit is None:
        return _blank_mnist_result(roi_bgr, bin_inv)

    digit = _deskew(digit)
    # Trim to the ink that survived deskewing, so the aspect ratio used for
    # scaling reflects the digit itself rather than empty margins.
    ys, xs = np.nonzero(digit)
    if ys.size == 0:
        return _blank_mnist_result(roi_bgr, bin_inv)
    digit = digit[ys.min():ys.max() + 1, xs.min():xs.max() + 1]

    canvas = _fit_to_mnist_canvas(digit)
    model_input = (canvas.astype(np.float32) / 255.0)[None, ..., None]
    return model_input, {"roi": roi_bgr, "binary": bin_inv, "final28": canvas}

In [None]:

# Trained Keras model file, resolved relative to the notebook's working directory.
MODEL_PATH = "mnist_model.keras"

model = keras.models.load_model(MODEL_PATH)
print("✅ Model loaded successfully!")


In [None]:
def _grab_sharpest_roi(cap, box, n_frames):
    """Read up to ``n_frames`` frames and return the crop ``box`` with the best sharpness.

    ``box`` is ``(x1, y1, x2, y2)``. Frames that fail to read are skipped.
    Returns None if no frame could be read. On ties the earliest frame wins.
    """
    x1, y1, x2, y2 = box
    best_score, best_roi = None, None
    for _ in range(n_frames):
        ok, frame = cap.read()
        if not ok:
            continue
        roi = frame[y1:y2, x1:x2]
        score = _sharpness_score(cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY))
        if best_score is None or score > best_score:
            best_score, best_roi = score, roi
    return best_roi


def _show_prediction(dbg, preds, digit, conf):
    """Plot the preprocessing stages and the per-class probability bar chart."""
    fig, axes = plt.subplots(1, 3, figsize=(10, 3.2))
    panels = [
        (cv2.cvtColor(dbg["roi"], cv2.COLOR_BGR2RGB), "Original ROI", {}),
        (dbg["binary"], "Binary (after cleanup)", {"cmap": "gray"}),
        (dbg["final28"], "Final 28x28 (MNIST-like)", {"cmap": "gray", "vmin": 0, "vmax": 255}),
    ]
    for ax, (img, title, imshow_kwargs) in zip(axes, panels):
        ax.imshow(img, **imshow_kwargs)
        ax.set_title(title)
        ax.axis("off")
    fig.tight_layout()
    plt.show()
    plt.close(fig)  # each capture creates new figures; don't let them accumulate

    classes = range(len(preds))
    fig, ax = plt.subplots(figsize=(5, 3))
    ax.bar(classes, preds)
    ax.set_xticks(classes)
    ax.set_title(f"Prediction: {digit}  |  Confidence: {conf:.2f}")
    ax.set_xlabel("Digit")
    ax.set_ylabel("Probability")
    fig.tight_layout()
    plt.show()
    plt.close(fig)


def test_with_webcam(model, roi_size=300, frames_to_sample=7, camera_index=0):
    """Live digit capture and classification from a webcam.

    Shows the camera feed with a centred ``roi_size`` x ``roi_size`` box.
    - Pressing 'c' reads ``frames_to_sample`` frames and keeps the sharpest
      crop of the box. It is preprocessed with ``preprocess_to_mnist``,
      classified with ``model.predict``, and the result is plotted.
    - Pressing 'q' quits.

    The camera and OpenCV windows are released in a ``finally`` block, so
    interrupting the kernel (the usual way to stop this loop in Jupyter) or an
    error during prediction does not leave the webcam locked.

    Args:
        model: Keras model whose ``predict`` returns class probabilities.
        roi_size: Side length in pixels of the capture box (clamped to the frame).
        frames_to_sample: Frames read per capture to pick the sharpest from.
        camera_index: OpenCV device index passed to ``cv2.VideoCapture``.
    """
    cap = cv2.VideoCapture(camera_index)
    if not cap.isOpened():
        print("Error: Could not open webcam")
        cap.release()
        return

    # Request 1280x720 @ 30 FPS; drivers may ignore values they don't support.
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)
    cap.set(cv2.CAP_PROP_FPS, 30)

    print("Webcam opened. Place ONE digit inside the square.")
    print("Press 'c' to capture (it will pick the sharpest of several frames), 'q' to quit.")

    window = "Digit Capture"
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                print("Error: Failed to read frame.")
                break

            h, w = frame.shape[:2]
            cx, cy = w // 2, h // 2
            half = roi_size // 2
            x1, y1 = max(0, cx - half), max(0, cy - half)
            x2, y2 = min(w, cx + half), min(h, cy + half)

            # Draw directly on the preview frame; captures read fresh frames.
            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(frame, "Put a single digit fully inside the box",
                        (max(10, x1), max(30, y1 - 10)), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
            cv2.putText(frame, "Press 'c' to capture, 'q' to quit",
                        (10, h - 15), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
            cv2.imshow(window, frame)

            key = cv2.waitKey(1) & 0xFF
            if key == ord('q'):
                break
            if key != ord('c'):
                continue

            best_roi = _grab_sharpest_roi(cap, (x1, y1, x2, y2), frames_to_sample)
            if best_roi is None:
                print("No frames captured; try again.")
                continue

            model_input, dbg = preprocess_to_mnist(best_roi)
            preds = model.predict(model_input, verbose=0)[0]
            digit = int(np.argmax(preds))
            conf = float(np.max(preds))

            _show_prediction(dbg, preds, digit, conf)
            print(f"Predicted digit: {digit} | confidence: {conf:.2f}")
    finally:
        cap.release()
        cv2.destroyAllWindows()
        cv2.waitKey(1)  # pump GUI events so the window actually closes on some platforms
        print("Webcam released.")


# Interactive: blocks until 'q' is pressed or the kernel is interrupted.
test_with_webcam(model)