In [3]:
!pip install pillow ipywidgets



In [1]:
import tkinter as tk
from tkinter import ttk
from PIL import Image, ImageDraw, ImageFilter
import threading
from IPython.display import display, clear_output
import ipywidgets as widgets
from io import BytesIO
import base64

In [None]:
class InteractiveDrawingApp:
    """Draw a digit in a Tkinter window and classify it from a Jupyter notebook.

    Every stroke drawn on the Tk canvas is mirrored into an in-memory PIL
    image (``self.pil_image``). ``predict_digit`` preprocesses that image into
    a ``(1, 1, 28, 28)`` tensor, runs ``self.model`` on it and reports the
    result through ipywidgets and a matplotlib figure.

    NOTE(review): this class never sets ``canvas_size``, ``brush_size``,
    ``model``, ``device``, ``main_widget``, ``prediction_label``,
    ``confidence_label``, ``prob_bars`` or ``prob_labels``, and it uses ``np``,
    ``torch`` and ``plt`` without importing them alongside it. Presumably
    another notebook cell provides them -- confirm before running.
    """

    def setup_widgets(self):
        """Create the ipywidgets output area used for messages and plots."""
        # NOTE(review): only ``self.output`` is built here; the other widgets
        # this class reads are not created anywhere in this class.
        self.output = widgets.Output()

    def display(self):
        """Show the app's top-level widget and a short usage hint."""
        # Inside a method body ``display`` is resolved in module globals, so
        # this calls IPython's ``display`` rather than recursing.
        display(self.main_widget)
        print("📝 Click 'Open Drawing Canvas' to start drawing digits!")

    def _reset_image(self):
        """Replace the backing PIL image with a blank white square."""
        self.pil_image = Image.new("RGB", (self.canvas_size, self.canvas_size), "white")
        self.pil_draw = ImageDraw.Draw(self.pil_image)

    def open_drawing_window(self, btn):
        """Open the Tk drawing window on a daemon thread.

        Args:
            btn: The ipywidgets button that triggered the call (unused).
        """
        # The new Tk canvas starts blank, so the backing image must too;
        # otherwise strokes from a previously closed window would still be
        # sent to the model.
        self._reset_image()

        def run_drawing_app():
            self.root = tk.Tk()
            self.root.title("Draw Your Digit Here!")
            self.root.geometry("400x500")
            self.root.configure(bg='#f0f0f0')

            title_label = tk.Label(self.root,
                                   text="🖊️ Draw a Digit (0-9)",
                                   font=("Arial", 16, "bold"),
                                   bg='#f0f0f0')
            title_label.pack(pady=10)

            self.canvas = tk.Canvas(self.root,
                                    width=self.canvas_size,
                                    height=self.canvas_size,
                                    bg='white',
                                    cursor='pencil')
            self.canvas.pack(pady=10)

            self.canvas.bind("<Button-1>", self.start_drawing)
            self.canvas.bind("<B1-Motion>", self.draw)
            self.canvas.bind("<ButtonRelease-1>", self.stop_drawing)

            control_frame = tk.Frame(self.root, bg='#f0f0f0')
            control_frame.pack(pady=10)

            clear_canvas_btn = tk.Button(control_frame,
                                         text="🗑️ Clear",
                                         command=self.clear_tkinter_canvas,
                                         font=("Arial", 12, "bold"),
                                         bg='#ff4444',
                                         fg='white',
                                         cursor='hand2')
            clear_canvas_btn.pack(side='left', padx=5)

            predict_canvas_btn = tk.Button(control_frame,
                                           text="🔍 Predict",
                                           command=self.predict_from_canvas,
                                           font=("Arial", 12, "bold"),
                                           bg='#4CAF50',
                                           fg='white',
                                           cursor='hand2')
            predict_canvas_btn.pack(side='left', padx=5)

            close_btn = tk.Button(control_frame,
                                  text="❌ Close",
                                  command=self.root.destroy,
                                  font=("Arial", 12, "bold"),
                                  bg='#666666',
                                  fg='white',
                                  cursor='hand2')
            close_btn.pack(side='left', padx=5)

            instructions = tk.Label(self.root,
                                    text="💡 Draw clearly in the center • Use thick strokes",
                                    font=("Arial", 10),
                                    bg='#f0f0f0',
                                    fg='#666666')
            instructions.pack(pady=5)

            # Let Tk compute the real window size, then center it on screen.
            self.root.update_idletasks()
            x = (self.root.winfo_screenwidth() - self.root.winfo_width()) // 2
            y = (self.root.winfo_screenheight() - self.root.winfo_height()) // 2
            self.root.geometry(f"+{x}+{y}")

            self.root.mainloop()

        # Run the Tk event loop off the notebook thread so the cell returns
        # immediately. NOTE(review): Tk is usually expected to run on the main
        # thread (notably on macOS); verify this works on the target platform.
        threading.Thread(target=run_drawing_app, daemon=True).start()

    def start_drawing(self, event):
        """Begin a stroke at the pointer position of a Tk ``<Button-1>`` event."""
        self.drawing = True
        self.last_x = event.x
        self.last_y = event.y

    def _can_extend_stroke(self):
        """Return True while a stroke is in progress and has a previous point."""
        # Compare with None: 0 is a valid coordinate on the canvas edge, and a
        # truthiness test would silently drop strokes starting there.
        return bool(
            getattr(self, "drawing", False)
            and self.last_x is not None
            and self.last_y is not None
        )

    def draw(self, event):
        """Extend the current stroke to the pointer on both Tk canvas and PIL image."""
        if not self._can_extend_stroke():
            return

        segment = [self.last_x, self.last_y, event.x, event.y]
        self.canvas.create_line(*segment,
                                width=self.brush_size,
                                fill='black',
                                capstyle=tk.ROUND,
                                smooth=tk.TRUE)
        self.pil_draw.line(segment, fill='black', width=self.brush_size)

        # Tk draws round caps, but PIL's wide lines have flat ends. Stamp a disc
        # at both segment ends so the image sent to the model has no notches
        # at the joints and matches what the user sees.
        radius = self.brush_size / 2
        for px, py in ((self.last_x, self.last_y), (event.x, event.y)):
            self.pil_draw.ellipse((px - radius, py - radius, px + radius, py + radius),
                                  fill='black')

        self.last_x = event.x
        self.last_y = event.y

    def stop_drawing(self, event):
        """End the current stroke (Tk ``<ButtonRelease-1>`` handler)."""
        self.drawing = False
        self.last_x = None
        self.last_y = None

    def clear_tkinter_canvas(self):
        """Erase the Tk canvas (if one exists) and reset the backing image."""
        if hasattr(self, 'canvas'):
            self.canvas.delete("all")
        self._reset_image()

    def clear_canvas(self, btn=None):
        """Reset the backing image and all prediction widgets.

        Args:
            btn: Triggering ipywidgets button, if any (unused).
        """
        # NOTE(review): this does not erase the Tk canvas itself, so a drawing
        # still visible there is no longer in the backing image.
        self._reset_image()

        self.prediction_label.value = "<h3>Draw a digit to see prediction</h3>"
        self.confidence_label.value = "<p><strong>Confidence:</strong> --</p>"

        for i in range(10):
            self.prob_bars[i].value = 0
            self.prob_bars[i].bar_style = 'info'
            self.prob_labels[i].value = "0.0%"

        with self.output:
            clear_output(wait=True)
            print("✅ Canvas cleared - ready for new digit!")

    def preprocess_image(self):
        """Convert the backing image into model input.

        Returns:
            ``(tensor, array)``: a float32 tensor of shape ``(1, 1, 28, 28)``
            and the same normalized 28x28 values as a NumPy array, which is
            used for the preview plot.
        """
        img = self.pil_image.convert('L')  # single-channel grayscale
        # Soften stroke edges before downscaling to 28x28.
        img = img.filter(ImageFilter.GaussianBlur(radius=1))
        img = img.resize((28, 28), Image.Resampling.LANCZOS)
        # The drawing is black on white; invert to bright ink on dark background.
        img_array = 255 - np.array(img)
        img_array = img_array.astype(np.float32) / 255.0
        # Presumably the MNIST mean/std the model was trained with -- confirm.
        img_array = (img_array - 0.1307) / 0.3081
        tensor = torch.tensor(img_array, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
        return tensor, img_array

    def predict_from_canvas(self):
        """Tk 'Predict' button callback; delegates to ``predict_digit``."""
        # NOTE(review): this runs on the Tk thread, so the widget and matplotlib
        # updates in predict_digit happen off the notebook thread.
        self.predict_digit(None)

    def predict_digit(self, btn=None):
        """Classify the current drawing and update the result widgets.

        Args:
            btn: Triggering ipywidgets button, if any (unused).

        Any exception is reported in ``self.output`` instead of propagating,
        because this is the top-level handler of a UI callback.
        """
        try:
            input_tensor, processed_img = self.preprocess_image()
            input_tensor = input_tensor.to(self.device)

            with torch.no_grad():
                output = self.model(input_tensor)
                probabilities = torch.softmax(output, dim=1)
                predicted_digit = output.argmax(dim=1).item()
                confidence = probabilities[0][predicted_digit].item() * 100

            color = '#2E8B57' if confidence > 70 else '#FF6347' if confidence > 40 else '#FF4500'
            self.prediction_label.value = f"<h3 style='color: {color}'>Predicted: {predicted_digit}</h3>"
            self.confidence_label.value = f"<p><strong>Confidence:</strong> <span style='color: {color}'>{confidence:.1f}%</span></p>"

            probs = probabilities[0].cpu().numpy()
            for i in range(10):
                prob_percent = probs[i] * 100
                self.prob_bars[i].value = prob_percent
                self.prob_labels[i].value = f"{prob_percent:.1f}%"
                # Highlight the winning class's bar.
                self.prob_bars[i].bar_style = 'success' if i == predicted_digit else 'info'

            with self.output:
                clear_output(wait=True)
                print(f"🎯 Prediction: {predicted_digit} (Confidence: {confidence:.1f}%)")

                fig, (ax_drawing, ax_processed) = plt.subplots(1, 2, figsize=(8, 3))

                ax_drawing.imshow(self.pil_image, cmap='gray')
                ax_drawing.set_title("Your Drawing", fontsize=12)
                ax_drawing.axis('off')

                ax_processed.imshow(processed_img, cmap='gray')
                ax_processed.set_title(f"Processed (28x28)\nPredicted: {predicted_digit}", fontsize=12)
                ax_processed.axis('off')

                fig.tight_layout()
                plt.show()

        except Exception as e:
            with self.output:
                clear_output(wait=True)
                print(f"❌ Error making prediction: {str(e)}")
                print("Make sure you have drawn something on the canvas!")