In [11]:
import io

import ipywidgets as widgets
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from IPython.display import display, clear_output
from PIL import Image, ImageGrab, ImageOps
from torchvision import transforms
import numpy as np

In [12]:
class MNISTModel(nn.Module):
    """Fully connected MNIST classifier: 784 inputs -> 512 -> 512 -> 10 logits.

    The submodule names (``flatten``, ``linear_relu_stack``) and the layer order
    inside the Sequential determine the state_dict keys, so they must stay in
    sync with the saved checkpoint that is loaded later.
    """

    def __init__(self):
        super().__init__()
        hidden_width = 512
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, 10),
        )

    def forward(self, x):
        """Flatten ``x`` from dim 1 onward and return the raw (unnormalised) class logits."""
        flat = self.flatten(x)
        return self.linear_relu_stack(flat)


In [13]:
# Report the PyTorch build, then pick the best available backend:
# CUDA first, then Apple MPS, falling back to the CPU.
print(f"Torch Version: {torch.__version__}")
print(f"CUDA Version (torch): {torch.version.cuda}")

if torch.cuda.is_available():
    print(f"Dispositivo: {torch.cuda.get_device_name(0)}")
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device} device")

Torch Version: 2.10.0+cu130
CUDA Version (torch): 13.0
Dispositivo: NVIDIA GeForce RTX 3050 6GB Laptop GPU
Using cuda device


In [14]:
# Rebuild the architecture on the selected device and load the trained parameters.
# torch.load unpickles the file; weights_only=True restricts it to tensors and plain
# containers so a tampered checkpoint cannot run arbitrary code. This is already the
# default on torch >= 2.6, but stating it keeps the notebook safe on older versions too.
WEIGHTS_PATH = "../weights/mnist_model.pth"
model = MNISTModel().to(device)
state_dict = torch.load(WEIGHTS_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
# eval() switches to inference mode and returns the module, whose repr is the cell output.
model.eval()


MNISTModel(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)

In [15]:
# Tensor conversion applied right before inference: ToTensor turns the PIL image into a
# float tensor scaled to [0, 1], channels first. No Normalize step is applied.
# NOTE(review): confirm this matches the preprocessing used when the weights were trained.
transform = transforms.Compose(
    [transforms.ToTensor()]
)

In [16]:
# File picker restricted to image MIME types; a single file at a time.
upload = widgets.FileUpload(
    accept='image/*',
    multiple=False
)

# Button that loads an image from the system clipboard (its click handler is attached separately).
paste_btn = widgets.Button(
    description="📋 Colar imagem (Ctrl+V)",
    button_style="info"
)

# Shared output area for the image preview and status/prediction messages.
output = widgets.Output()

In [17]:
# Most recently loaded image (from the upload widget or the clipboard); None until one is loaded.
# Written by the load_* callbacks via `global`, read by show_image() and predict().
current_image = None


def preprocess_for_model(img: Image.Image) -> Image.Image:
    """Return a PIL grayscale 28x28 image suitable for the model.

    Steps:
      1. Flatten any transparency onto a white background. A plain ``convert("L")``
         simply discards the alpha channel, so a dark digit drawn on a transparent
         canvas whose hidden pixels are black would collapse into an all-black image.
      2. Convert to 8-bit grayscale and resize to 28x28 (aspect ratio is not preserved).
      3. Invert the image when its mean intensity is above 0.5, i.e. when it is mostly
         bright. NOTE(review): this assumes the model expects light digits on a dark
         background — confirm against the training data.

    Args:
        img: Any PIL image (RGB, RGBA, L, LA, P with transparency, ...).

    Returns:
        A mode-"L" image of size 28x28.
    """
    # Composite transparent pixels onto white so the background is well defined.
    if "A" in img.getbands() or "transparency" in img.info:
        rgba = img.convert("RGBA")
        white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        img = Image.alpha_composite(white, rgba)

    # Convert to grayscale and resize to the model input size.
    gray = img.convert("L").resize((28, 28))

    # Decide whether to invert: a high mean pixel value means a bright (white) background.
    arr = np.asarray(gray) / 255.0
    if arr.mean() > 0.5:
        gray = ImageOps.invert(gray)

    return gray


def load_from_upload(change):
    global current_image
    file = next(iter(upload.value.values()))
    current_image = Image.open(io.BytesIO(file['content']))
    show_image()


def load_from_clipboard(button):
    global current_image
    img = ImageGrab.grabclipboard()
    if img is None:
        with output:
            print("‚ùå Nenhuma imagem no clipboard")
        return
    current_image = img
    show_image()


def show_image():
    """Render ``current_image`` inside the ``output`` widget, replacing any previous content.

    A fresh figure/axes pair is created on every call, so the preview never draws onto
    whatever figure pyplot happens to consider "current".
    """
    with output:
        clear_output()
        _, ax = plt.subplots()
        # cmap only affects single-channel images; RGB(A) images are shown in colour.
        ax.imshow(current_image, cmap="gray")
        ax.set_axis_off()
        plt.show()

In [18]:
def predict(button):
    """Button callback: classify ``current_image`` and print the predicted digit in ``output``.

    The image is passed through ``preprocess_for_model`` (grayscale, 28x28, optional
    inversion) and ``transform``. The index of the largest of the model's output
    logits is reported.
    """
    if current_image is None:
        with output:
            print("⚠️ Nenhuma imagem carregada")
        return

    # Preprocess image (grayscale, resize, and optional invert) to match the training distribution.
    prepped = preprocess_for_model(current_image)

    # unsqueeze(0) adds the leading batch dimension before moving the tensor to the model's device.
    batch = transform(prepped).unsqueeze(0).to(device)

    with torch.no_grad():
        pred = model(batch).argmax(dim=1).item()

    with output:
        print(f"🧠 Número reconhecido: {pred}")


In [19]:
predict_btn = widgets.Button(
    description="🧠 Classificar",
    button_style="success"
)

# Wire each widget to its callback: new uploads, clipboard paste, and classification.
upload.observe(load_from_upload, names='value')
paste_btn.on_click(load_from_clipboard)
predict_btn.on_click(predict)

# Assemble the UI vertically: title, inputs, action button, then the shared output area.
display(
    widgets.VBox([
        widgets.HTML("<h3>Reconhecimento de Dígitos (MNIST)</h3>"),
        upload,
        paste_btn,
        predict_btn,
        output
    ])
)

VBox(children=(HTML(value='<h3>Reconhecimento de Dígitos (MNIST)</h3>'), FileUpload(value=(), accept='image/*'…