In [1]:
!pip install torch torchvision matplotlib numpy



In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

In [None]:
class DigitRecognizer(nn.Module):
    """Fully connected MNIST classifier mapping an image to 10 class logits.

    Layer widths: 784 -> 512 -> 256 -> 128 -> 10, with ReLU after each hidden
    layer and dropout (p=0.2) after the first two hidden layers. Layers live
    in ``self.network`` in the same order/indices as before, so saved
    ``state_dict`` files stay compatible.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        hidden_stack = [
            nn.Linear(28 * 28, 512), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(256, 128), nn.ReLU(),
        ]
        self.network = nn.Sequential(*hidden_stack, nn.Linear(128, 10))

    def forward(self, x):
        """Flatten ``x`` from dim 1 onward and return raw (pre-softmax) logits."""
        flat = self.flatten(x)
        return self.network(flat)

In [None]:
def load_data(batch_size=64):
    """Download (if needed) MNIST into ./data and wrap both splits in DataLoaders.

    Each image is converted to a tensor and standardized with mean 0.1307 and
    std 0.3081 (the values commonly used for MNIST).

    Returns:
        (train_loader, test_loader): the training loader shuffles each epoch;
        the test loader iterates in a fixed order.
    """
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    def _mnist_split(is_train):
        return datasets.MNIST(root='./data', train=is_train, download=True, transform=preprocess)

    train_split = _mnist_split(True)
    test_split = _mnist_split(False)

    return (
        DataLoader(train_split, batch_size=batch_size, shuffle=True),
        DataLoader(test_split, batch_size=batch_size, shuffle=False),
    )

In [None]:
def train(model, train_loader, epochs=5, lr=0.001, device=None):
    """Train ``model`` with Adam and cross-entropy loss, printing per-epoch stats.

    Args:
        model: ``nn.Module`` mapping a batch of inputs to class logits.
        train_loader: sized iterable (e.g. ``DataLoader``) of ``(data, target)``
            batches, where ``target`` holds integer class indices.
        epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
        device: device to train on; defaults to CUDA when available, else CPU.

    Returns:
        The model, moved to ``device`` and left in training mode.

    Raises:
        ValueError: if ``train_loader`` contains no batches.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if len(train_loader) == 0:
        # Fail with a clear message instead of dividing by zero in the epoch stats.
        raise ValueError("train_loader is empty; nothing to train on")

    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    print(f"Training on {device}")
    print("-" * 40)

    for epoch in range(epochs):
        model.train()
        loss_sum = 0.0
        correct = 0
        total = 0

        for data, target in train_loader:
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            batch_size = target.size(0)
            # criterion returns the batch mean; re-weight by batch size so a
            # smaller final batch does not skew the epoch's per-sample average.
            loss_sum += loss.item() * batch_size
            correct += output.argmax(dim=1).eq(target).sum().item()
            total += batch_size

        accuracy = 100.0 * correct / total
        avg_loss = loss_sum / total
        print(f"Epoch {epoch+1}/{epochs} | Loss: {avg_loss:.4f} | Accuracy: {accuracy:.2f}%")

    return model

In [None]:

def evaluate(model, test_loader, device=None):
    """Compute, print, and return the top-1 accuracy (percent) of ``model``.

    Args:
        model: ``nn.Module`` producing class logits; it is switched to eval mode.
        test_loader: iterable of ``(data, target)`` batches.
        device: device to run on; defaults to CUDA when available, else CPU.

    Returns:
        Accuracy as a float in [0, 100].

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            predicted = model(data).argmax(dim=1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()

    if total == 0:
        # Avoid a bare ZeroDivisionError with no hint about the cause.
        raise ValueError("test_loader yielded no samples; cannot compute accuracy")

    accuracy = 100.0 * correct / total
    print(f"\nTest Accuracy: {accuracy:.2f}%")
    return accuracy

In [None]:
def predict_digit(model, image, device=None):
    """Classify a single image and return ``(predicted_class, confidence_percent)``.

    Args:
        model: ``nn.Module`` producing class logits; it is switched to eval mode.
        image: one image as ``(H, W)``, ``(C, H, W)`` or ``(1, C, H, W)``;
            missing leading dimensions are added so the model always sees a
            4-D batch of size one.
        device: device to run on; defaults to CUDA when available, else CPU.

    Returns:
        The argmax class index (int) and its softmax probability times 100.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    with torch.no_grad():
        batch = image.to(device)
        # Promote (H, W) -> (1, 1, H, W) and (C, H, W) -> (1, C, H, W); previously
        # a 2-D image was passed through un-batched and broke the Linear layer.
        while batch.dim() < 4:
            batch = batch.unsqueeze(0)
        output = model(batch)
        probabilities = torch.softmax(output, dim=1)
        predicted = output.argmax(dim=1).item()
        confidence = probabilities[0, predicted].item() * 100

    return predicted, confidence

In [None]:
def visualize_predictions(model, test_loader, num_samples=10, save_path='predictions.png'):
    """Plot model predictions for the first images of one test batch.

    Each panel is titled with the prediction, its confidence, and the true
    label. The title is green when the prediction is correct and red otherwise.
    The figure is saved to ``save_path`` and shown inline.

    Args:
        model: classifier accepted by ``predict_digit``.
        test_loader: iterable of ``(images, labels)`` batches; only the first
            batch is used.
        num_samples: number of images to show; clamped to the first batch's size.
        save_path: output PNG path (defaults to the original 'predictions.png').

    Raises:
        ValueError: if there is nothing to plot (``num_samples`` < 1 or empty batch).
    """
    model.eval()
    images, labels = next(iter(test_loader))

    # Never index past the batch, and size the grid from the sample count
    # (max 5 per row) instead of a fixed 2x5 grid that broke above 10 samples.
    count = min(num_samples, len(images))
    if count < 1:
        raise ValueError("num_samples must be >= 1 and the first batch must be non-empty")
    ncols = min(5, count)
    nrows = (count + ncols - 1) // ncols
    fig, axes = plt.subplots(nrows, ncols, figsize=(12 * ncols / 5, 3 * nrows), squeeze=False)
    axes = axes.flatten()

    for ax, img, label in zip(axes, images[:count], labels[:count]):
        true_label = label.item()
        pred, conf = predict_digit(model, img)

        ax.imshow(img.squeeze(), cmap='gray')
        color = 'green' if pred == true_label else 'red'
        ax.set_title(f"Pred: {pred} ({conf:.1f}%)\nTrue: {true_label}", color=color)
        ax.axis('off')

    # Blank out unused cells in a partially filled last row.
    for ax in axes[count:]:
        ax.axis('off')

    fig.tight_layout()
    fig.savefig(save_path, dpi=150)
    plt.show()
    print(f"Predictions saved to '{save_path}'")

In [None]:
def save_model(model, path='digit_recognizer.pth'):
    """Write only ``model.state_dict()`` (parameters and buffers) to ``path``.

    The architecture is not serialized, so reloading requires constructing the
    same module class first.
    """
    weights = model.state_dict()
    torch.save(weights, path)
    print(f"Model saved to '{path}'")


In [None]:
def load_model(path='digit_recognizer.pth', map_location='cpu'):
    """Rebuild a ``DigitRecognizer`` from a saved ``state_dict`` file.

    Args:
        path: file written by ``torch.save(model.state_dict(), path)``.
        map_location: where loaded tensors are placed. Defaults to CPU, so
            weights saved during a GPU run still load on a CPU-only machine.

    Returns:
        A ``DigitRecognizer`` in eval mode (dropout disabled), ready for inference.
    """
    model = DigitRecognizer()
    # weights_only=True restricts unpickling to tensors and plain containers.
    state_dict = torch.load(path, map_location=map_location, weights_only=True)
    model.load_state_dict(state_dict)
    # A freshly built nn.Module starts in training mode; inference needs eval.
    model.eval()
    print(f"Model loaded from '{path}'")
    return model

In [None]:
if __name__ == "__main__":
    # Seed torch's global RNG (weight init, dropout masks, DataLoader shuffling)
    # so Restart & Run All gives repeatable results; some CUDA kernels may
    # still be nondeterministic.
    SEED = 0
    torch.manual_seed(SEED)

    print("=" * 40)
    print("Handwritten Digit Recognition")
    print("=" * 40)

    print("\n[1/4] Loading MNIST dataset...")
    train_loader, test_loader = load_data(batch_size=64)
    print(f"Training samples: {len(train_loader.dataset)}")
    print(f"Test samples: {len(test_loader.dataset)}")

    print("\n[2/4] Creating neural network...")
    model = DigitRecognizer()
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {total_params:,}")

    print("\n[3/4] Training model...")
    model = train(model, train_loader, epochs=5, lr=0.001)

    print("\n[4/4] Evaluating model...")
    evaluate(model, test_loader)

    save_model(model)

    print("\nGenerating prediction visualization...")
    visualize_predictions(model, test_loader)
    
    

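# 🎨 Interactive Drawing Interface

Now let's create an interactive drawing interface where you can draw digits and get on-demand predictions from the trained model. The drawing window is a Tkinter window. It only opens when the Jupyter kernel runs on a machine with a desktop display, so it will not work on Colab or on a headless server.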

In [None]:
# Install additional packages for the interactive interface
!pip install pillow ipywidgets

In [None]:
import tkinter as tk
from tkinter import ttk
from PIL import Image, ImageDraw, ImageFilter
import threading
from IPython.display import display, clear_output
import ipywidgets as widgets
from io import BytesIO
import base64

In [None]:
class InteractiveDrawingApp:
    """Draw a digit in a Tkinter window and classify it from the notebook.

    The notebook side is an ipywidgets panel. It holds the buttons, the
    predicted label, the confidence, per-digit probability bars, and an Output
    area for plots. The drawing side is a Tkinter window run on a background
    daemon thread. Every stroke is mirrored onto an off-screen PIL image, and
    that image is what gets preprocessed and fed to the model.

    NOTE(review): Tkinter is not thread-safe and needs a display. Running its
    mainloop off the main thread may fail on some platforms (notably macOS),
    and no window can be opened on Colab or a headless server.
    """

    def __init__(self, model):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Keep the model on the same device as the input tensors and disable
        # dropout: a freshly constructed/loaded nn.Module is in training mode.
        self.model = model.to(self.device)
        self.model.eval()
        self.canvas_size = 280  # drawing area in pixels (10x the 28x28 model input)
        self.brush_size = 12
        self.drawing = False
        self.last_x = None
        self.last_y = None

        # Off-screen mirror of the drawing; this is what gets classified.
        self._reset_image()

        # Initialize widgets
        self.setup_widgets()

    def _reset_image(self):
        """Replace the off-screen drawing with a blank white RGB image."""
        self.pil_image = Image.new("RGB", (self.canvas_size, self.canvas_size), "white")
        self.pil_draw = ImageDraw.Draw(self.pil_image)

    def setup_widgets(self):
        """Build the notebook widget panel and wire up its button callbacks."""
        # Output area for status messages and the drawing/processed-image plot
        self.output = widgets.Output()

        self.clear_btn = widgets.Button(
            description="🗑️ Clear Canvas",
            button_style='danger',
            layout=widgets.Layout(width='150px')
        )
        self.predict_btn = widgets.Button(
            description="🔍 Predict Now",
            button_style='success',
            layout=widgets.Layout(width='150px')
        )
        self.draw_btn = widgets.Button(
            description="🖊️ Open Drawing Canvas",
            button_style='info',
            layout=widgets.Layout(width='200px')
        )

        self.clear_btn.on_click(self.clear_canvas)
        self.predict_btn.on_click(self.predict_digit)
        self.draw_btn.on_click(self.open_drawing_window)

        self.prediction_label = widgets.HTML(
            value="<h3>Draw a digit to see prediction</h3>",
            layout=widgets.Layout(width='300px')
        )

        self.confidence_label = widgets.HTML(
            value="<p><strong>Confidence:</strong> --</p>",
            layout=widgets.Layout(width='300px')
        )

        # One progress bar plus a percentage label per digit class
        self.prob_bars = {}
        self.prob_labels = {}
        for i in range(10):
            self.prob_bars[i] = widgets.FloatProgress(
                value=0,
                min=0,
                max=100,
                description=f'Digit {i}:',
                bar_style='info',
                layout=widgets.Layout(width='250px')
            )
            self.prob_labels[i] = widgets.HTML(value="0.0%")

        button_box = widgets.HBox([self.draw_btn, self.clear_btn, self.predict_btn])
        prediction_box = widgets.VBox([self.prediction_label, self.confidence_label])

        prob_container = widgets.VBox([
            widgets.HBox([self.prob_bars[i], self.prob_labels[i]]) for i in range(10)
        ])

        self.main_widget = widgets.VBox([
            button_box,
            widgets.HTML("<hr>"),
            prediction_box,
            widgets.HTML("<h4>All Digit Probabilities:</h4>"),
            prob_container,
            self.output
        ])

    def display(self):
        """Show the widget panel in the notebook."""
        # The bare name `display` resolves to the module-level IPython.display
        # import, not to this method (methods are not in a function's scope).
        display(self.main_widget)
        print("📝 Click 'Open Drawing Canvas' to start drawing digits!")

    def open_drawing_window(self, btn):
        """Open the Tkinter drawing window on a background daemon thread.

        ``btn`` is the ipywidgets Button that triggered the call (unused).
        """
        def run_drawing_app():
            # A new window starts blank, so the mirrored PIL image must too;
            # otherwise strokes from a previous window would still be classified.
            self._reset_image()
            self.drawing = False
            self.last_x = None
            self.last_y = None

            self.root = tk.Tk()
            self.root.title("Draw Your Digit Here!")
            self.root.geometry("400x500")
            self.root.configure(bg='#f0f0f0')

            title_label = tk.Label(self.root,
                                   text="🖊️ Draw a Digit (0-9)",
                                   font=("Arial", 16, "bold"),
                                   bg='#f0f0f0')
            title_label.pack(pady=10)

            self.canvas = tk.Canvas(self.root,
                                    width=self.canvas_size,
                                    height=self.canvas_size,
                                    bg='white',
                                    cursor='pencil')
            self.canvas.pack(pady=10)

            # Left mouse button: press starts a stroke, drag extends it, release ends it
            self.canvas.bind("<Button-1>", self.start_drawing)
            self.canvas.bind("<B1-Motion>", self.draw)
            self.canvas.bind("<ButtonRelease-1>", self.stop_drawing)

            control_frame = tk.Frame(self.root, bg='#f0f0f0')
            control_frame.pack(pady=10)

            clear_canvas_btn = tk.Button(control_frame,
                                         text="🗑️ Clear",
                                         command=self.clear_tkinter_canvas,
                                         font=("Arial", 12, "bold"),
                                         bg='#ff4444',
                                         fg='white',
                                         cursor='hand2')
            clear_canvas_btn.pack(side='left', padx=5)

            predict_canvas_btn = tk.Button(control_frame,
                                           text="🔍 Predict",
                                           command=self.predict_from_canvas,
                                           font=("Arial", 12, "bold"),
                                           bg='#4CAF50',
                                           fg='white',
                                           cursor='hand2')
            predict_canvas_btn.pack(side='left', padx=5)

            close_btn = tk.Button(control_frame,
                                  text="❌ Close",
                                  command=self.root.destroy,
                                  font=("Arial", 12, "bold"),
                                  bg='#666666',
                                  fg='white',
                                  cursor='hand2')
            close_btn.pack(side='left', padx=5)

            instructions = tk.Label(self.root,
                                    text="💡 Draw clearly in the center • Use thick strokes",
                                    font=("Arial", 10),
                                    bg='#f0f0f0',
                                    fg='#666666')
            instructions.pack(pady=5)

            # Center the window on screen once its size is known
            self.root.update_idletasks()
            x = (self.root.winfo_screenwidth() - self.root.winfo_width()) // 2
            y = (self.root.winfo_screenheight() - self.root.winfo_height()) // 2
            self.root.geometry(f"+{x}+{y}")

            self.root.mainloop()

        # Run in a separate thread so the notebook cell does not block on mainloop()
        threading.Thread(target=run_drawing_app, daemon=True).start()

    def start_drawing(self, event):
        """Begin a stroke at the mouse-press position."""
        self.drawing = True
        self.last_x = event.x
        self.last_y = event.y

    def draw(self, event):
        """Extend the current stroke on both the Tk canvas and the PIL mirror."""
        # Compare against None: 0 is a valid coordinate at the canvas edge,
        # and the previous truthiness check silently dropped those segments.
        if not self.drawing or self.last_x is None or self.last_y is None:
            return

        self.canvas.create_line(self.last_x, self.last_y,
                                event.x, event.y,
                                width=self.brush_size,
                                fill='black',
                                capstyle=tk.ROUND,
                                smooth=tk.TRUE)

        self.pil_draw.line([self.last_x, self.last_y, event.x, event.y],
                           fill='black',
                           width=self.brush_size)
        # Tk draws round caps; stamp discs at both segment ends so the PIL
        # copy has no notches where consecutive segments meet.
        radius = self.brush_size / 2
        for px, py in ((self.last_x, self.last_y), (event.x, event.y)):
            self.pil_draw.ellipse([px - radius, py - radius, px + radius, py + radius],
                                  fill='black')

        self.last_x = event.x
        self.last_y = event.y

    def stop_drawing(self, event):
        """End the current stroke."""
        self.drawing = False
        self.last_x = None
        self.last_y = None

    def clear_tkinter_canvas(self):
        """Erase the Tk canvas (if open) and the PIL mirror."""
        if hasattr(self, 'canvas'):
            self.canvas.delete("all")
        self._reset_image()

    def clear_canvas(self, btn=None):
        """Reset the PIL mirror and all prediction widgets.

        NOTE(review): this runs on the notebook thread and does not erase the
        strokes shown in an open Tk window.
        """
        self._reset_image()

        self.prediction_label.value = "<h3>Draw a digit to see prediction</h3>"
        self.confidence_label.value = "<p><strong>Confidence:</strong> --</p>"

        for i in range(10):
            self.prob_bars[i].value = 0
            self.prob_bars[i].bar_style = 'info'
            self.prob_labels[i].value = "0.0%"

        with self.output:
            clear_output(wait=True)
            print("✅ Canvas cleared - ready for new digit!")

    def preprocess_image(self):
        """Convert the drawing into a normalized, MNIST-style model input.

        The image is converted to grayscale and inverted, so strokes are light
        on a black background as in MNIST. The digit's bounding box is then
        scaled so its longer side is 20 px and centered in a 28x28 frame. Pixel
        values are scaled to [0, 1] and standardized with mean 0.1307 and
        std 0.3081, the values ``load_data`` uses.

        Returns:
            (tensor of shape (1, 1, 28, 28), the same normalized 28x28 float32 array)

        Raises:
            ValueError: if nothing has been drawn.
        """
        inverted = self.pil_image.convert('L').point(lambda p: 255 - p)
        # Slight blur smooths jagged mouse strokes before downscaling
        inverted = inverted.filter(ImageFilter.GaussianBlur(radius=1))

        bbox = inverted.getbbox()  # None when every pixel is 0 (blank canvas)
        if bbox is None:
            raise ValueError("The canvas is empty.")

        digit = inverted.crop(bbox)
        width, height = digit.size
        scale = 20 / max(width, height)
        new_size = (max(1, round(width * scale)), max(1, round(height * scale)))
        digit = digit.resize(new_size, Image.Resampling.LANCZOS)

        framed = Image.new('L', (28, 28), 0)
        framed.paste(digit, ((28 - new_size[0]) // 2, (28 - new_size[1]) // 2))

        img_array = np.asarray(framed, dtype=np.float32) / 255.0
        img_array = (img_array - 0.1307) / 0.3081

        tensor = torch.from_numpy(img_array).unsqueeze(0).unsqueeze(0)
        return tensor, img_array

    def predict_from_canvas(self):
        """Tk 'Predict' button handler; delegates to ``predict_digit``."""
        self.predict_digit(None)

    def predict_digit(self, btn=None):
        """Classify the current drawing and refresh every prediction widget."""
        try:
            input_tensor, processed_img = self.preprocess_image()
            input_tensor = input_tensor.to(self.device)

            with torch.no_grad():
                output = self.model(input_tensor)
                probabilities = torch.softmax(output, dim=1)
                predicted_digit = output.argmax(dim=1).item()
                confidence = probabilities[0, predicted_digit].item() * 100

            # Green above 70%, tomato above 40%, orange-red otherwise
            color = '#2E8B57' if confidence > 70 else '#FF6347' if confidence > 40 else '#FF4500'
            self.prediction_label.value = f"<h3 style='color: {color}'>Predicted: {predicted_digit}</h3>"
            self.confidence_label.value = f"<p><strong>Confidence:</strong> <span style='color: {color}'>{confidence:.1f}%</span></p>"

            probs = probabilities[0].cpu().numpy()
            for i in range(10):
                prob_percent = probs[i] * 100
                self.prob_bars[i].value = prob_percent
                self.prob_labels[i].value = f"{prob_percent:.1f}%"
                # Highlight the winning class
                self.prob_bars[i].bar_style = 'success' if i == predicted_digit else 'info'

            with self.output:
                clear_output(wait=True)
                print(f"🎯 Prediction: {predicted_digit} (Confidence: {confidence:.1f}%)")

                # Side by side: the raw drawing and the 28x28 input the model saw
                fig, (ax_raw, ax_proc) = plt.subplots(1, 2, figsize=(8, 3))

                ax_raw.imshow(self.pil_image, cmap='gray')
                ax_raw.set_title("Your Drawing", fontsize=12)
                ax_raw.axis('off')

                ax_proc.imshow(processed_img, cmap='gray')
                # Real newline; the previous "\\n" rendered a literal backslash-n.
                ax_proc.set_title(f"Processed (28x28)\nPredicted: {predicted_digit}", fontsize=12)
                ax_proc.axis('off')

                fig.tight_layout()
                plt.show()

        except Exception as e:
            # Best-effort UI: report the problem in the notebook instead of raising
            with self.output:
                clear_output(wait=True)
                print(f"❌ Error making prediction: {str(e)}")
                print("Make sure you have drawn something on the canvas!")

## 🚀 Launch Interactive Drawing Interface

Now let's create and launch the interactive drawing interface!

In [None]:
# Create the interactive drawing interface.
# Prefers the `model` trained earlier in this kernel; otherwise loads the
# weights file 'digit_recognizer.pth' via load_model().

try:
    if 'model' in globals():
        print("✅ Using the model from training...")
        app_model = model
    else:
        print("📁 Loading saved model...")
        app_model = load_model()

    # Inference only: make sure dropout is off whichever model we got.
    app_model.eval()
    drawing_app = InteractiveDrawingApp(app_model)

    print("🎨 Interactive Drawing Interface Ready!")
    print("👆 Use the interface below to draw digits and get predictions!")
    drawing_app.display()

except Exception as e:
    # Setup is best-effort: explain likely causes instead of a raw traceback.
    print(f"❌ Error setting up drawing interface: {str(e)}")
    print("💡 Make sure you have:")
    print("   1. Trained a model (run the training cells above)")
    print("   2. Or have a saved model file 'digit_recognizer.pth'")

## 📱 How to Use the Interactive Interface

### Step-by-Step Instructions:

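1. **Run the cell above** to initialize the drawing interface.
2. **Click "Open Drawing Canvas"** in the notebook panel to open a separate drawing window.
3. **Draw a digit (0-9)** with your mouse in the white canvas area.
4. **Click "Predict"** in the drawing window, or "Predict Now" in the notebook, to get the model's prediction with confidence scores.
5. **View the results in the notebook panel**: the predicted digit, its confidence, probability bars for all ten digits, and a side-by-side view of your drawing and the processed 28x28 input. The drawing window itself only shows your strokes.
6. **Click "Clear"** in the drawing window to erase your strokes before drawing a new digit. The notebook's "Clear Canvas" button resets the prediction display.
7. **Try different digits** to test the model's accuracy!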

### Features:
- ✏️ **Real-time Drawing**: Smooth drawing with mouse input
- 🎯 **On-demand Predictions**: Click Predict to get the predicted digit with a confidence score
- 📊 **Probability Bars**: See the model's probability for every digit (0-9)
- 🖼️ **Image Processing**: View your drawing next to the 28x28 image the model actually sees
- 🔄 **Easy Reset**: Clear the canvas to try new digits
- 📈 **Visual Feedback**: Predictions are color-coded by confidence

### Tips for Better Accuracy:
- Draw digits **large and centered** in the canvas
- Use **thick, clear strokes**
- Make sure digits are **well-formed** and recognizable
- Try **different writing styles** to test robustness