In [3]:
!pip install pillow ipywidgets



In [4]:
import tkinter as tk
from tkinter import ttk
from PIL import Image, ImageDraw, ImageFilter
import threading
from IPython.display import display, clear_output
import ipywidgets as widgets
from io import BytesIO
import base64

In [5]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [6]:
class DigitRecognizer(nn.Module):
    """Fully connected classifier mapping a 28x28 input to 10 class logits.

    The input is flattened and passed through three hidden layers
    (512 -> 256 -> 128) with ReLU activations; dropout (p=0.2) follows the
    first two hidden layers. The final layer emits raw logits (no softmax).
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()

        # The position of each layer inside the Sequential determines the
        # state_dict key names (network.0, network.3, ...), so the order
        # must stay fixed for saved weights to keep loading.
        layers = [
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return the 10 logits for each sample in ``x``."""
        flat = self.flatten(x)
        logits = self.network(flat)
        return logits

In [7]:
def load_model(path='digit_recognizer.pth'):
    """Build a DigitRecognizer and restore its weights from ``path``.

    The checkpoint is read onto the CPU with ``weights_only=True`` and the
    returned model is switched to evaluation mode.

    Args:
        path: Location of the saved ``state_dict`` file.

    Returns:
        The restored model in eval mode.
    """
    recognizer = DigitRecognizer()
    saved_weights = torch.load(path, map_location='cpu', weights_only=True)
    recognizer.load_state_dict(saved_weights)
    recognizer.eval()
    print(f"Model loaded from '{path}'")
    return recognizer
def create_mock_model():
    """Create an untrained DigitRecognizer for exercising the interface.

    The weights are whatever the default initialisation produced, so the
    predictions carry no meaning. The model is returned in evaluation mode,
    matching ``load_model``, so dropout is not applied at inference time.

    Returns:
        An untrained model in eval mode.
    """
    print("‚ö†Ô∏è Creating mock model for testing - predictions will be random!")
    print("üí° For real predictions, train a model using the main notebook first.")
    model = DigitRecognizer()
    # A freshly built nn.Module starts in training mode; switch it off so
    # repeated predictions on the same drawing are stable.
    model.eval()
    return model

In [8]:
class InteractiveDrawingApp:
    """Notebook front end for drawing a digit and classifying it.

    The app has two parts:

    * an ipywidgets panel (buttons, prediction labels, per-digit probability
      bars and an Output area) shown in the notebook, and
    * a Tk window with a drawing canvas, started on a daemon thread.

    Every stroke drawn on the Tk canvas is mirrored onto an off-screen PIL
    image; that image, not the Tk canvas, is what gets fed to the model.
    """

    def __init__(self, model):
        """Prepare the model, drawing state and notebook widgets.

        Args:
            model: A ``torch.nn.Module`` producing 10 logits per sample.
                It is moved to the selected device and switched to eval
                mode in place.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # predict_digit moves the input to self.device, so the model has to
        # live there too; eval() keeps dropout out of inference.
        self.model = model.to(self.device)
        self.model.eval()
        self.canvas_size = 280
        self.brush_size = 12
        self.drawing = False
        self.last_x = None
        self.last_y = None
        self._reset_image()
        self.setup_widgets()

    def _reset_image(self):
        """Replace the off-screen buffer with a blank white image."""
        self.pil_image = Image.new("RGB", (self.canvas_size, self.canvas_size), "white")
        self.pil_draw = ImageDraw.Draw(self.pil_image)

    def setup_widgets(self):
        """Build the ipywidgets control panel and hook up the buttons."""
        # Collects printed status lines and figures produced by callbacks.
        self.output = widgets.Output()

        self.clear_btn = widgets.Button(
            description="üóëÔ∏è Clear Canvas",
            button_style='danger',
            layout=widgets.Layout(width='150px')
        )
        self.predict_btn = widgets.Button(
            description="üîç Predict Now",
            button_style='success',
            layout=widgets.Layout(width='150px')
        )
        self.draw_btn = widgets.Button(
            description="üñäÔ∏è Open Drawing Canvas",
            button_style='info',
            layout=widgets.Layout(width='200px')
        )

        self.clear_btn.on_click(self.clear_canvas)
        self.predict_btn.on_click(self.predict_digit)
        self.draw_btn.on_click(self.open_drawing_window)

        self.prediction_label = widgets.HTML(
            value="<h3>Draw a digit to see prediction</h3>",
            layout=widgets.Layout(width='300px')
        )
        self.confidence_label = widgets.HTML(
            value="<p><strong>Confidence:</strong> --</p>",
            layout=widgets.Layout(width='300px')
        )

        # One progress bar plus percentage label per digit class.
        self.prob_bars = {}
        self.prob_labels = {}
        prob_rows = []
        for digit in range(10):
            bar = widgets.FloatProgress(
                value=0,
                min=0,
                max=100,
                description=f'Digit {digit}:',
                bar_style='info',
                layout=widgets.Layout(width='250px')
            )
            label = widgets.HTML(value="0.0%")
            self.prob_bars[digit] = bar
            self.prob_labels[digit] = label
            prob_rows.append(widgets.HBox([bar, label]))

        button_box = widgets.HBox([self.draw_btn, self.clear_btn, self.predict_btn])
        prediction_box = widgets.VBox([self.prediction_label, self.confidence_label])
        prob_container = widgets.VBox(prob_rows)

        self.main_widget = widgets.VBox([
            button_box,
            widgets.HTML("<hr>"),
            prediction_box,
            widgets.HTML("<h4>All Digit Probabilities:</h4>"),
            prob_container,
            self.output
        ])

    def display(self):
        """Render the control panel in the notebook."""
        display(self.main_widget)
        print("üìù Click 'Open Drawing Canvas' to start drawing digits!")

    def open_drawing_window(self, btn):
        """Open the Tk drawing window on a daemon thread.

        Args:
            btn: The widget button that fired the callback (unused).
        """
        # The new window always starts with an empty canvas, so the mirrored
        # buffer must start empty too; otherwise strokes from a previous
        # window would silently be part of the next prediction.
        self._reset_image()

        def run_drawing_app():
            # Everything Tk-related is created and used on this thread; the
            # canvas/button callbacks below therefore also run here.
            window_bg = '#f0f0f0'
            button_font = ("Arial", 12, "bold")

            self.root = tk.Tk()
            self.root.title("Draw Your Digit Here!")
            self.root.geometry("400x500")
            self.root.configure(bg=window_bg)

            title_label = tk.Label(self.root,
                                   text="üñäÔ∏è Draw a Digit (0-9)",
                                   font=("Arial", 16, "bold"),
                                   bg=window_bg)
            title_label.pack(pady=10)

            self.canvas = tk.Canvas(self.root,
                                    width=self.canvas_size,
                                    height=self.canvas_size,
                                    bg='white',
                                    cursor='pencil')
            self.canvas.pack(pady=10)

            self.canvas.bind("<Button-1>", self.start_drawing)
            self.canvas.bind("<B1-Motion>", self.draw)
            self.canvas.bind("<ButtonRelease-1>", self.stop_drawing)

            control_frame = tk.Frame(self.root, bg=window_bg)
            control_frame.pack(pady=10)

            clear_canvas_btn = tk.Button(control_frame,
                                         text="üóëÔ∏è Clear",
                                         command=self.clear_tkinter_canvas,
                                         font=button_font,
                                         bg='#ff4444',
                                         fg='white',
                                         cursor='hand2')
            clear_canvas_btn.pack(side='left', padx=5)

            predict_canvas_btn = tk.Button(control_frame,
                                           text="üîç Predict",
                                           command=self.predict_from_canvas,
                                           font=button_font,
                                           bg='#4CAF50',
                                           fg='white',
                                           cursor='hand2')
            predict_canvas_btn.pack(side='left', padx=5)

            close_btn = tk.Button(control_frame,
                                  text="‚ùå Close",
                                  command=self.root.destroy,
                                  font=button_font,
                                  bg='#666666',
                                  fg='white',
                                  cursor='hand2')
            close_btn.pack(side='left', padx=5)

            instructions = tk.Label(self.root,
                                    text="üí° Draw clearly in the center ‚Ä¢ Use thick strokes",
                                    font=("Arial", 10),
                                    bg=window_bg,
                                    fg='#666666')
            instructions.pack(pady=5)

            # Lay the window out first so winfo_width/height report real
            # sizes, then centre it on the screen.
            self.root.update_idletasks()
            x = (self.root.winfo_screenwidth() - self.root.winfo_width()) // 2
            y = (self.root.winfo_screenheight() - self.root.winfo_height()) // 2
            self.root.geometry(f"+{x}+{y}")

            self.root.mainloop()

        threading.Thread(target=run_drawing_app, daemon=True).start()

    def start_drawing(self, event):
        """Begin a stroke at the pointer position."""
        self.drawing = True
        self.last_x = event.x
        self.last_y = event.y

    def draw(self, event):
        """Extend the current stroke to the pointer position.

        The segment is drawn on the Tk canvas and on the off-screen PIL
        image so both stay in step.
        """
        # Compare against None explicitly: 0 is a valid canvas coordinate
        # (left/top edge) and must not be treated as "no previous point".
        if not self.drawing or self.last_x is None or self.last_y is None:
            return

        self.canvas.create_line(self.last_x, self.last_y,
                                event.x, event.y,
                                width=self.brush_size,
                                fill='black',
                                capstyle=tk.ROUND,
                                smooth=tk.TRUE)
        self.pil_draw.line([self.last_x, self.last_y, event.x, event.y],
                           fill='black',
                           width=self.brush_size)

        self.last_x = event.x
        self.last_y = event.y

    def stop_drawing(self, event):
        """End the current stroke."""
        self.drawing = False
        self.last_x = None
        self.last_y = None

    def clear_tkinter_canvas(self):
        """Erase the Tk canvas (if one was created) and the mirrored image."""
        if hasattr(self, 'canvas'):
            self.canvas.delete("all")
        self._reset_image()

    def clear_canvas(self, btn=None):
        """Reset the mirrored image and every prediction widget.

        This is the notebook-side clear button; it does not touch the Tk
        canvas.

        Args:
            btn: The widget button that fired the callback (unused).
        """
        self._reset_image()

        self.prediction_label.value = "<h3>Draw a digit to see prediction</h3>"
        self.confidence_label.value = "<p><strong>Confidence:</strong> --</p>"

        for digit in range(10):
            self.prob_bars[digit].value = 0
            self.prob_bars[digit].bar_style = 'info'
            self.prob_labels[digit].value = "0.0%"

        with self.output:
            clear_output(wait=True)
            print("‚úÖ Canvas cleared - ready for new digit!")

    def preprocess_image(self):
        """Turn the mirrored drawing into a model input.

        Steps: grayscale -> slight Gaussian blur -> resize to 28x28 ->
        invert (so strokes are bright on a dark background) -> scale to
        [0, 1] -> normalise.

        Returns:
            A ``(tensor, array)`` pair: a float32 tensor of shape
            (1, 1, 28, 28) and the normalised 28x28 float32 numpy array
            it was built from.
        """
        img = self.pil_image.convert('L')
        img = img.filter(ImageFilter.GaussianBlur(radius=1))
        img = img.resize((28, 28), Image.Resampling.LANCZOS)

        img_array = 255 - np.array(img)
        img_array = img_array.astype(np.float32) / 255.0
        # NOTE(review): 0.1307 / 0.3081 are presumably the mean/std used to
        # normalise the training data - confirm against the training notebook.
        img_array = (img_array - 0.1307) / 0.3081

        tensor = torch.FloatTensor(img_array).unsqueeze(0).unsqueeze(0)
        return tensor, img_array

    def predict_from_canvas(self):
        """Tk button adapter: run a prediction without a widget argument."""
        self.predict_digit(None)

    def predict_digit(self, btn=None):
        """Classify the current drawing and refresh the notebook widgets.

        Updates the prediction/confidence labels and the ten probability
        bars, then shows the raw drawing next to the preprocessed 28x28
        input in the Output area. Any exception is reported in the Output
        area instead of being raised.

        Args:
            btn: The widget button that fired the callback (unused).
        """
        try:
            input_tensor, processed_img = self.preprocess_image()
            input_tensor = input_tensor.to(self.device)

            with torch.no_grad():
                output = self.model(input_tensor)
                probabilities = torch.softmax(output, dim=1)
                predicted_digit = output.argmax(dim=1).item()
                confidence = probabilities[0][predicted_digit].item() * 100

            if confidence > 70:
                color = '#2E8B57'
            elif confidence > 40:
                color = '#FF6347'
            else:
                color = '#FF4500'
            self.prediction_label.value = f"<h3 style='color: {color}'>Predicted: {predicted_digit}</h3>"
            self.confidence_label.value = f"<p><strong>Confidence:</strong> <span style='color: {color}'>{confidence:.1f}%</span></p>"

            probs = probabilities[0].cpu().numpy()
            for digit in range(10):
                prob_percent = probs[digit] * 100
                self.prob_bars[digit].value = prob_percent
                self.prob_labels[digit].value = f"{prob_percent:.1f}%"
                # Highlight only the winning class.
                self.prob_bars[digit].bar_style = 'success' if digit == predicted_digit else 'info'

            with self.output:
                clear_output(wait=True)
                print(f"üéØ Prediction: {predicted_digit} (Confidence: {confidence:.1f}%)")

                fig, (ax_raw, ax_processed) = plt.subplots(1, 2, figsize=(8, 3))

                ax_raw.imshow(self.pil_image, cmap='gray')
                ax_raw.set_title("Your Drawing", fontsize=12)
                ax_raw.axis('off')

                ax_processed.imshow(processed_img, cmap='gray')
                # A real newline: the escaped form printed a literal "\n"
                # in the middle of the title.
                ax_processed.set_title(f"Processed (28x28)\nPredicted: {predicted_digit}", fontsize=12)
                ax_processed.axis('off')

                fig.tight_layout()
                plt.show()
                # Release the figure so repeated predictions do not pile up
                # open figures in the kernel.
                plt.close(fig)

        except Exception as e:
            with self.output:
                clear_output(wait=True)
                print(f"‚ùå Error making prediction: {str(e)}")
                print("Make sure you have drawn something on the canvas!")

In [None]:
# Build the drawing app from the best model available, in priority order:
#   1. a `model` variable already present in the kernel,
#   2. weights loaded via load_model() (default file 'digit_recognizer.pth'),
#   3. an untrained mock model, so the interface can still be exercised.
try:
    # At cell top level locals() is the notebook's global namespace, so this
    # picks up any variable named `model` left behind by an earlier cell.
    if 'model' in locals():
        print("‚úÖ Using the model from training...")
        drawing_app = InteractiveDrawingApp(model)
    else:
        print("üìÅ Trying to load saved model...")
        try:
            saved_model = load_model()
            drawing_app = InteractiveDrawingApp(saved_model)
        except FileNotFoundError:
            # No checkpoint on disk: fall back to the mock model.
            print("‚ö†Ô∏è No saved model found. Creating mock model for interface testing...")
            mock_model = create_mock_model()
            drawing_app = InteractiveDrawingApp(mock_model)
        except Exception as load_error:
            # Any other failure raised inside the inner try (loading the
            # weights or building the app) also falls back to the mock model.
            print(f"‚ùå Error loading saved model: {str(load_error)}")
            print("‚ö†Ô∏è Creating mock model for interface testing...")
            mock_model = create_mock_model()
            drawing_app = InteractiveDrawingApp(mock_model)
    
    print("üé® Interactive Drawing Interface Ready!")
    print("üëÜ Use the interface below to draw digits and get predictions!")
    drawing_app.display()
    
except Exception as e:
    # Last-resort handler: report the problem and list the prerequisites
    # instead of leaving a traceback in the notebook.
    print(f"‚ùå Error setting up drawing interface: {str(e)}")
    print("üí° Make sure you have:")
    print("   1. Trained a model (run the training cells above)")
    print("   2. Or have a saved model file 'digit_recognizer.pth'")
    print("   3. All required packages installed (torch, PIL, etc.)")

üìÅ Trying to load saved model...
Model loaded from 'digit_recognizer.pth'
üé® Interactive Drawing Interface Ready!
üëÜ Use the interface below to draw digits and get predictions!


VBox(children=(HBox(children=(Button(button_style='info', description='üñäÔ∏è Open Drawing Canvas', layout=Layout(‚Ä¶

üìù Click 'Open Drawing Canvas' to start drawing digits!


# 🎨 Interactive Handwritten Digit Recognition

Welcome to the interactive drawing interface! This tool allows you to:

- **Draw digits** (0-9) directly on a canvas
- **Get real-time predictions** from your trained neural network
- **View confidence scores** for all possible digits
- **See how the AI processes** your handwritten input

## 🚀 How to Use:

1. **Run the cell below** to launch the interface
2. **Click "Open Drawing Canvas"** to open the drawing window
3. **Draw a digit** using your mouse in the white area
4. **Click "Predict"** to see what the AI thinks you drew
5. **Try different digits** and drawing styles!

---

In [None]:
def open_drawing_canvas():
    """Open the drawing canvas directly without using the widget button.

    Delegates to ``drawing_app.open_drawing_window`` so the window is the
    same one the widget button opens: strokes are mirrored into the app's
    image buffer and the Predict button is available. A separately built
    window would show a canvas whose drawing the model could never see.

    Prints a hint and does nothing when ``drawing_app`` has not been
    created yet.
    """
    if 'drawing_app' not in globals():
        print("‚ùå Drawing app not initialized. Please run the setup cell first.")
        return

    # The callback signature expects the clicked button; there is none here.
    drawing_app.open_drawing_window(None)
    print("‚úÖ Drawing canvas opened! Draw a digit and close the window when done.")
def check_interface():
    """Report whether the drawing interface is loaded in this kernel.

    Looks for a global ``drawing_app`` and, when present, prints the device
    it reports. Errors raised while inspecting the app are reported with
    their message; interrupts (Ctrl-C / kernel interrupt) are not swallowed.
    """
    if 'drawing_app' not in globals():
        print("‚ùå Drawing app not found. Please run the setup cell first.")
        return

    print("‚úÖ Drawing app is loaded and ready!")
    print("üìù You can use: open_drawing_canvas() to open the drawing window")
    try:
        print(f"üéØ Model device: {drawing_app.device}")
        print("üé® Interface components loaded successfully")
    except Exception as exc:
        # Catch ordinary errors only, and say what went wrong instead of
        # hiding it behind a generic warning.
        print("‚ö†Ô∏è Some interface components may have issues")
        print(f"   Details: {exc}")

# Announce the helpers defined in this cell and how to call them.
print(
    "üõ†Ô∏è Helper functions loaded!",
    "üìù Use: open_drawing_canvas() to open drawing window",
    "üîç Use: check_interface() to check if everything is working",
    sep="\n",
)

üõ†Ô∏è Helper functions loaded!
üìù Use: open_drawing_canvas() to open drawing window
üîç Use: check_interface() to check if everything is working


In [15]:

# Launch the drawing window from code, as an alternative to clicking the
# "Open Drawing Canvas" widget button. Requires the setup cell to have run.
print("üöÄ Opening drawing canvas...")
open_drawing_canvas()

üöÄ Opening drawing canvas...
‚úÖ Drawing canvas opened! Draw a digit and close the window when done.
