<a href="https://www.kaggle.com/code/ahmedelmaamounamin/handwritten-digit-recognition-with-gui?scriptVersionId=146504834" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# Handwritten Digit Recognition with GUI

This project implements a handwritten digit recognition system trained on the MNIST dataset and provides a graphical user interface (GUI) in which users draw digits and have them recognized by the trained neural network. It is written in Python and uses PyTorch for the model, Tkinter for the interface, and Pillow and NumPy for image handling.

## Project Components
1. **Neural Network Model**

   A convolutional neural network (CNN) for handwritten digit recognition is defined in PyTorch as the `Net` class.

   The model consists of two 3×3 convolutional layers (32 and 64 channels), a 2×2 max-pooling step, two fully connected layers (9216 → 128 → 10), and dropout (p = 0.5) to reduce overfitting.

2. **Data Loading and Preprocessing**

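   The MNIST dataset is loaded with torchvision and preprocessed: images are converted to tensors and normalized with mean 0.5 and standard deviation 0.5, which maps pixel values to [-1, 1]. No data augmentation is applied in the code shown here.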

3. **Training**

   The model is trained on the MNIST training split for 10 epochs with a batch size of 128, using the Adam optimizer (learning rate 0.001) and cross-entropy loss. The loss is printed every 100 batches.

4. **Saving and Loading Model**

   The trained weights (`state_dict`) are saved to `mnist_model.pth`, then loaded into a fresh `Net` instance that is switched to evaluation mode.

5. **GUI for Digit Recognition**

   A GUI is created using the Tkinter library, allowing users to draw digits on a canvas.

6. **Digit Prediction**

   The drawn digit is converted to grayscale, resized to 28×28 pixels, and converted into a PyTorch tensor.

   The trained model is used to predict the digit, and the result is displayed to the user.

7. **Clear and Predict Buttons**

   The GUI features buttons to clear the canvas and initiate the digit prediction.
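
Step 6 deserves care: MNIST digits are bright strokes on a dark background, while the canvas holds dark strokes on a white background, and the training data is normalized to [-1, 1]. The following is a minimal sketch of preprocessing that accounts for both, assuming the drawing is already a 28×28 grayscale `uint8` array. The function name `to_model_input` and the sample stroke are hypothetical and not part of the notebook.

```python
import numpy as np

def to_model_input(gray):
    """Turn a 28x28 uint8 image (dark strokes on white) into a model-ready array."""
    pixels = gray.astype(np.float32)
    pixels = 255.0 - pixels                  # invert: strokes become bright
    pixels = (pixels / 255.0 - 0.5) / 0.5    # scale to [0, 1], then to [-1, 1]
    return pixels.reshape(1, 1, 28, 28)      # (batch, channel, height, width)

# Illustrative input: a blank white canvas with a short black stroke.
canvas_pixels = np.full((28, 28), 255, dtype=np.uint8)
canvas_pixels[10:18, 13:15] = 0

x = to_model_input(canvas_pixels)
print(x.shape, x.min(), x.max())
```

The result can be wrapped with `torch.from_numpy(x)` before it is passed to the model.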

## How to Use
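1. Remove the triple quotes (`'''`) that wrap the GUI code in the last code cell, then run the notebook. Tkinter opens a desktop window, so this requires an environment with a display.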
2. The GUI will appear, allowing you to draw a digit on the canvas.
3. Click the "Clear" button to erase the drawing and start over.
4. Click the "Predict" button to have the model recognize and display the predicted digit.

This project combines machine learning, computer vision, and GUI development to create an interactive handwritten digit recognition tool. Users can draw digits, and the model predicts the written number, making it a great demonstration of AI in action.


In [1]:
# Import libraries 

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import numpy as np
from PIL import Image, ImageDraw
import tkinter as tk

In [2]:
# Define a CNN model in PyTorch
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(12*12*64, 128)
        self.fc2 = nn.Linear(128, 10)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = self.dropout(x)
        x = x.view(-1, 12*12*64)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return torch.log_softmax(x, dim=1)
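
The `12*12*64` input size of `fc1` follows from the layer shapes. As a quick check, the sketch below recomputes it with the standard output-size formula for an unpadded convolution or pooling window. The helper `out_size` is a hypothetical name introduced for illustration; it is not part of the notebook.

```python
def out_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling window (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1

side = 28                                   # MNIST images are 28x28
side = out_size(side, kernel=3)             # conv1: 28 -> 26
side = out_size(side, kernel=3)             # conv2: 26 -> 24
side = out_size(side, kernel=2, stride=2)   # max_pool2d(x, 2): 24 -> 12

flat_features = side * side * 64            # 64 channels after conv2
print(side, flat_features)                  # 12 9216
```

If the kernel sizes or the pooling step change, the first argument of `nn.Linear` in `fc1` and the `view` call in `forward` have to change with them.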

In [3]:
# Load MNIST data and apply transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True, transform=transform),
    batch_size=128, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transform),
    batch_size=128, shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 89374958.34it/s]


Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 104247585.05it/s]


Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 24002094.13it/s]


Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 13588108.96it/s]

Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw
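
`transforms.Normalize((0.5,), (0.5,))` computes `(x - mean) / std` for each pixel. Because `ToTensor()` has already scaled the pixels to [0, 1], the result lies in [-1, 1]. A minimal plain-Python sketch of that arithmetic (the function name `normalize` is illustrative only):

```python
def normalize(x, mean=0.5, std=0.5):
    """Same per-pixel arithmetic as transforms.Normalize: (x - mean) / std."""
    return (x - mean) / std

# ToTensor() maps the raw byte range 0..255 to 0.0..1.0 before Normalize runs.
black, mid_gray, white = 0 / 255, 127.5 / 255, 255 / 255
normalized = [normalize(p) for p in (black, mid_gray, white)]
print(normalized)  # [-1.0, 0.0, 1.0]
```

Any image passed to the model at prediction time needs the same scaling, otherwise the model sees inputs in a range it was not trained on.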






In [4]:
# Create the model, optimizer, and loss function
model = Net()
optimizer = optim.Adam(model.parameters(), lr=0.001)
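# The model already returns log-probabilities (log_softmax), so NLLLoss is the
# matching loss; together they compute the cross-entropy of the raw logits.
criterion = nn.NLLLoss()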

# Training loop
def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

# Train the model
for epoch in range(1, 11):
    train(epoch)

# Save the model
torch.save(model.state_dict(), "mnist_model.pth")

# Load the model
model = Net()
model.load_state_dict(torch.load("mnist_model.pth"))
model.eval()



Net(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=9216, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
)
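
The notebook builds `test_loader` but never uses it. One way to measure accuracy on the held-out test set is sketched below. The `evaluate` helper, the stand-in model `AlwaysZero`, and the single demo batch are assumptions added for illustration; in the notebook the call would be `evaluate(model, test_loader)`.

```python
import torch
import torch.nn as nn

def evaluate(model, loader):
    """Return the fraction of correctly classified samples in `loader`."""
    model.eval()                     # disable dropout for evaluation
    correct, total = 0, 0
    with torch.no_grad():            # no gradients needed for inference
        for data, target in loader:
            pred = model(data).argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
    return correct / total

class AlwaysZero(nn.Module):
    """Hypothetical stand-in model that always predicts digit 0."""
    def forward(self, x):
        scores = torch.zeros(x.size(0), 10)
        scores[:, 0] = 1.0
        return scores

# One demo batch of four blank images with labels 0, 0, 1, 2.
demo_loader = [(torch.zeros(4, 1, 28, 28), torch.tensor([0, 0, 1, 2]))]
demo_accuracy = evaluate(AlwaysZero(), demo_loader)
print(demo_accuracy)  # 0.5
```

Calling `model.eval()` inside the helper matters here because `Net` applies dropout, which would otherwise randomly zero activations during evaluation.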

## Digit Recognition: Draw and Predict

To use the Handwritten Digit Recognition GUI, follow these steps:

1. Remove the triple quotes (`'''`) that wrap the code below.
2. Run the code cell.

After running the code, the GUI will appear, allowing you to draw a digit on the canvas. The window provides two buttons:

- **Clear:** Click the "Clear" button to erase the drawing and start over.
- **Predict:** Click the "Predict" button to have the model recognize and display the predicted digit.
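
Because `Net.forward` ends in `log_softmax`, the model output is a vector of ten log-probabilities. The predicted digit is the index of the largest entry, and exponentiating that entry gives the probability the model assigns to it. A small sketch with made-up numbers (the helper `top_prediction` and the sample distribution are hypothetical):

```python
import math

def top_prediction(log_probs):
    """Return (digit, probability) for a list of ten log-probabilities."""
    digit = max(range(len(log_probs)), key=lambda i: log_probs[i])
    return digit, math.exp(log_probs[digit])

# Made-up distribution: 70% on digit 2, the rest spread evenly.
probs = [0.3 / 9] * 10
probs[2] = 0.7
log_probs = [math.log(p) for p in probs]

digit, confidence = top_prediction(log_probs)
print(digit, round(confidence, 3))  # 2 0.7
```

Showing the probability next to the digit is one way to signal when a drawing is ambiguous.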

In [5]:
'''
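def predict_digit(img):
    # Convert to grayscale and resize to the 28x28 input the model expects
    img = img.convert('L').resize((28, 28))
    img = np.array(img, dtype=np.float32)

    # Invert colors: the canvas is black-on-white, MNIST is white-on-black
    img = 255.0 - img

    # Scale to [0, 1], then apply the same normalization used in training
    img = (img / 255.0 - 0.5) / 0.5
    img = img.reshape(1, 1, 28, 28)

    # Convert to PyTorch tensor
    img = torch.from_numpy(img)

    # Predict the digit
    with torch.no_grad():
        output = model(img)
    _, predicted = torch.max(output, 1)
    return predicted.item()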

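def draw(event):
    x = event.x
    y = event.y
    # Show the stroke on the visible Tk canvas
    canvas.create_oval(x - 4, y - 4, x + 4, y + 4, fill='black', outline='black')
    # Mirror the stroke on the in-memory image that is passed to the model
    draw_canvas.line([(x, y), (x+1, y+1)], fill='black', width=8)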

def clear():
    global image, draw_canvas
    image = Image.new("RGB", (200, 200), (255, 255, 255))
    draw_canvas = ImageDraw.Draw(image)
    canvas.delete("all")

def predict():
    digit = predict_digit(image)
    label.configure(text=str(digit))



root = tk.Tk()

canvas = tk.Canvas(root, width=200, height=200, bg='white')
canvas.grid(row=0, column=0, pady=2, sticky=tk.W)

button_clear = tk.Button(root, text="Clear", command=clear)
button_clear.grid(row=1, column=0, pady=2)

button_predict = tk.Button(root, text="Predict", command=predict)
button_predict.grid(row=1, column=1, pady=2)

label = tk.Label(root, text="", font=("Helvetica", 48))
label.grid(row=0, column=1, pady=2, padx=2)

canvas.bind("<B1-Motion>", draw)

image = Image.new("RGB", (200, 200), (255, 255, 255))
draw_canvas = ImageDraw.Draw(image)

root.mainloop()

'''
