In [3]:
from typing import List, Optional
import urllib.request
from tqdm.auto import tqdm
from pathlib import Path
import requests
import torch
import math
import numpy as np
import os
import glob

def get_quickdraw_class_names():
    """Fetch the official QuickDraw category list from GitHub.

    Returns:
        List of class names in the order of ``categories.txt``, with spaces replaced by
        underscores (e.g. ``"The Eiffel Tower"`` -> ``"The_Eiffel_Tower"``). Blank lines
        are skipped.

    Raises:
        requests.HTTPError: if the server answers with an error status. Without this
            check an error page would be silently parsed as a list of class names.
    """
    url = "https://raw.githubusercontent.com/googlecreativelab/quickdraw-dataset/master/categories.txt"
    # A timeout keeps the notebook from hanging forever on a stalled connection.
    r = requests.get(url, timeout=30)
    r.raise_for_status()
    return [line.strip().replace(' ', '_') for line in r.text.splitlines() if line.strip()]


def download_quickdraw_dataset(root="../QuickDraw", limit: Optional[int] = None, class_names: Optional[List[str]] = None):
    """Download QuickDraw 28x28 numpy bitmaps, one ``<class_name>.npy`` per class, into ``root``.

    Args:
        root: Destination directory; created (with parents) if missing.
        limit: If given, only the first ``limit`` entries of ``class_names`` are downloaded.
        class_names: Class names with spaces encoded as underscores. Defaults to the full
            list returned by ``get_quickdraw_class_names()`` (a network request).

    Files that already exist are skipped. Each download is written to ``<name>.npy.part``
    first and renamed to ``<name>.npy`` only after it completes. An interrupted download
    is therefore retried on the next call instead of being mistaken for a finished file.
    """
    if class_names is None:
        class_names = get_quickdraw_class_names()

    root = Path(root)
    root.mkdir(exist_ok=True, parents=True)
    url = 'https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/'

    print("Downloading Quickdraw Dataset...")
    for class_name in tqdm(class_names[:limit]):
        fpath = root / f"{class_name}.npy"
        if fpath.exists():
            continue
        tmp_path = fpath.with_name(fpath.name + ".part")
        try:
            # Remote file names use the original spaces, URL-encoded as %20.
            urllib.request.urlretrieve(f"{url}{class_name.replace('_', '%20')}.npy", tmp_path)
            tmp_path.replace(fpath)
        except BaseException:
            # Also covers KeyboardInterrupt: never leave a partial file behind.
            tmp_path.unlink(missing_ok=True)
            raise


def load_quickdraw_data(root="../QuickDraw", max_items_per_class=5000, class_names: Optional[List[str]] = None):
    """Load QuickDraw bitmaps from ``<root>/<class>.npy`` files into memory.

    Args:
        root: Directory containing the per-class ``.npy`` files.
        max_items_per_class: Keep at most this many rows from each file (``None`` keeps all).
        class_names: Classes to load. Defaults to every ``*.npy`` file in ``root``, which
            may include classes downloaded by earlier runs with a different selection.

    Returns:
        ``(x, y, names)``:
        - ``x`` is an ``(N, 784)`` array of flattened bitmaps.
        - ``y`` holds integer labels.
        - ``names[label]`` is the class name of a label.
        Files are processed in sorted path order, so labels follow that order.

    Raises:
        FileNotFoundError: if a requested class has no ``.npy`` file in ``root``.
    """
    root = Path(root)
    if class_names is None:
        files = sorted(root.glob('*.npy'))
    else:
        # A set drops duplicate names so each class gets exactly one label.
        files = sorted({root / f"{name}.npy" for name in class_names})

    # Collect chunks and concatenate once at the end. Concatenating inside the loop
    # re-copies everything loaded so far on every iteration (quadratic in #classes).
    x_chunks = [np.empty([0, 784], dtype=np.uint8)]
    y_chunks = [np.empty([0], dtype=int)]
    loaded_names = []

    print(f"Loading {max_items_per_class} examples for each class from the Quickdraw Dataset...")
    for idx, file in enumerate(tqdm(files)):
        # Copy the slice out of the memory map so the mapping can be released right away.
        data = np.array(np.load(file, mmap_mode='r')[0:max_items_per_class, :])
        x_chunks.append(data)
        y_chunks.append(np.full(data.shape[0], idx, dtype=int))
        loaded_names.append(file.stem)

    return np.concatenate(x_chunks, axis=0), np.concatenate(y_chunks), loaded_names

class QuickDrawDataset(torch.utils.data.Dataset):
    """In-memory QuickDraw bitmap dataset; missing class files are downloaded on construction.

    Args:
        root: Directory holding (or receiving) the ``<class>.npy`` files.
        max_items_per_class: Maximum number of drawings loaded per class.
        class_limit: Use only the first ``class_limit`` classes of the selection.
        class_names: Classes to use; defaults to the official category list (network request).

    Only the selected classes are loaded, even if ``root`` also contains ``.npy`` files
    from earlier runs. ``self.classes[label]`` is the class name of an integer label.
    Items are ``(float32 tensor of shape (1, 28, 28) scaled to [0, 1], int label)``.
    """

    def __init__(self, root, max_items_per_class=5000, class_limit=None, class_names = None):
        super().__init__()
        self.root = root
        self.max_items_per_class = max_items_per_class
        self.class_limit = class_limit
        self.class_names = class_names

        # Resolve the selection once so that downloading and loading agree on it.
        selected = class_names if class_names is not None else get_quickdraw_class_names()
        selected = list(selected)[:class_limit]

        download_quickdraw_dataset(self.root, self.class_limit, selected)
        self.X, self.Y, self.classes = load_quickdraw_data(
            self.root, self.max_items_per_class, class_names=selected
        )

    def __getitem__(self, idx):
        x = (self.X[idx] / 255.).astype(np.float32).reshape(1, 28, 28)
        y = self.Y[idx]

        return torch.from_numpy(x), y.item()

    def __len__(self):
        return len(self.X)

    def collate_fn(self, batch):
        """Stack ``(image, label)`` pairs into ``{'pixel_values': (B,1,28,28), 'labels': (B,)}``."""
        x = torch.stack([item[0] for item in batch])
        y = torch.LongTensor([item[1] for item in batch])
        return {'pixel_values': x, 'labels': y}

    def split(self, pct=0.1, generator=None):
        """Randomly split into ``(train, val)`` subsets; val gets ``floor(len * pct)`` items.

        Args:
            pct: Fraction of the items used for validation.
            generator: Optional ``torch.Generator`` for a reproducible split; defaults to
                torch's global RNG.

        Returns:
            Two disjoint ``torch.utils.data.Subset`` objects covering the whole dataset.
        """
        indices = torch.randperm(len(self), generator=generator).tolist()
        n_val = math.floor(len(indices) * pct)
        # Slice by the train size. With n_val == 0, ``indices[:-n_val]`` would be
        # ``indices[:0]`` (an empty train set) and ``indices[-0:]`` would be everything.
        n_train = len(indices) - n_val
        train_ds = torch.utils.data.Subset(self, indices[:n_train])
        val_ds = torch.utils.data.Subset(self, indices[n_train:])
        return train_ds, val_ds

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from pathlib import Path

# Import QuickDrawDataset from cell 1
from __main__ import QuickDrawDataset

class QuickDrawCNN(nn.Module):
    """Small CNN for 1x28x28 bitmaps: two conv/ReLU/max-pool stages, then a 2-layer MLP head.

    Each 2x max-pool halves the spatial size (28 -> 14 -> 7), hence the 32*7*7 input
    features of the first linear layer. The attribute names (``conv``, ``fc``) and the
    layer order inside each ``nn.Sequential`` define the state_dict keys (``conv.0.weight``,
    ``fc.2.bias``, ...). Keep them stable so saved checkpoints keep loading.
    """

    def __init__(self, num_classes):
        super().__init__()
        feature_layers = [
            nn.Conv2d(1, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        head_layers = [
            nn.Linear(32 * 7 * 7, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        ]
        self.conv = nn.Sequential(*feature_layers)
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Map a ``(batch, 1, 28, 28)`` input to ``(batch, num_classes)`` logits."""
        feature_maps = self.conv(x)
        flat = feature_maps.view(feature_maps.size(0), -1)
        return self.fc(flat)

if __name__ == "__main__":
    # Settings
    root = "../QuickDraw"
    class_limit = 10  # e.g., train on 10 classes for speed
    max_items_per_class = 2000
    num_epochs = 3
    # Seed torch's global RNG: it drives the train/val split, weight init and shuffling.
    torch.manual_seed(0)

    # Prepare dataset
    dataset = QuickDrawDataset(root, max_items_per_class=max_items_per_class, class_limit=class_limit)
    # TODO: select class names from https://raw.githubusercontent.com/googlecreativelab/quickdraw-dataset/master/categories.txt
    train_ds, val_ds = dataset.split(pct=0.1)
    train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, collate_fn=dataset.collate_fn)
    val_loader = DataLoader(val_ds, batch_size=64, shuffle=False, collate_fn=dataset.collate_fn)

    # Save class names for later use in the app.
    # NOTE(review): the 9-layer training cell writes this same file; whichever cell ran
    # last wins, and the drawing apps pair it with quickdraw_model_9layer.pth.
    class_names = dataset.classes
    Path(root).mkdir(exist_ok=True)
    with open(Path(root) / "class_names.txt", "w") as f:
        f.write("\n".join(class_names))

    # Instantiate model, loss, optimizer
    model = QuickDrawCNN(num_classes=len(class_names))
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Training loop, with an accuracy check on the held-out split after every epoch
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}")
        model.train()
        loop = tqdm(train_loader, desc="Training", leave=False)
        for batch in loop:
            images = batch['pixel_values']
            labels = batch['labels']
            preds = model(images)
            loss = criterion(preds, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loop.set_postfix(loss=loss.item())

        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for batch in val_loader:
                logits = model(batch['pixel_values'])
                correct += (logits.argmax(dim=1) == batch['labels']).sum().item()
                total += batch['labels'].numel()
        val_acc = correct / total if total else float("nan")
        print(f"Epoch {epoch + 1} done - val accuracy: {val_acc:.2%}")

    torch.save(model.state_dict(), Path(root) / "quickdraw_model.pth")
    print("Model and class names saved.")

In [None]:
# 9-layer CNN (from https://github.com/nateraw/quickdraw-pytorch)
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from pathlib import Path

# Import QuickDrawDataset from cell 1
from __main__ import QuickDrawDataset

class QuickDrawCNN9L(nn.Module):
    """CNN for 1x28x28 bitmaps: three conv/ReLU/max-pool stages (64, 128, 256 channels) + MLP.

    The three 2x max-pools shrink 28 -> 14 -> 7 -> 3, giving 256*3*3 = 2304 flattened
    features. All layers live in one ``nn.Sequential`` named ``model``. Its layer order
    defines the checkpoint keys (``model.0``, ``model.3``, ``model.6``, ``model.10``,
    ``model.12``), so keep it unchanged.
    """

    def __init__(self, num_classes):
        super().__init__()
        layers = []
        for in_channels, out_channels in ((1, 64), (64, 128), (128, 256)):
            layers.extend([
                nn.Conv2d(in_channels, out_channels, 3, padding='same'),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ])
        layers.extend([
            nn.Flatten(),
            nn.Linear(2304, 512),
            nn.ReLU(),
            nn.Linear(512, num_classes),
        ])
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Map a ``(batch, 1, 28, 28)`` input to ``(batch, num_classes)`` logits."""
        return self.model(x)

if __name__ == "__main__":
    # Settings
    root = "../QuickDraw"
    class_limit = 10  # e.g., train on 10 classes for speed
    max_items_per_class = 2000
    num_epochs = 3
    # Seed torch's global RNG: it drives the train/val split, weight init and shuffling.
    torch.manual_seed(0)

    # Prepare dataset
    dataset = QuickDrawDataset(root, max_items_per_class=max_items_per_class, class_limit=class_limit, 
                               class_names=["apple", "The_Eiffel_Tower", "cat", "smiley_face", "sun", "toothbrush", "pizza", "hedgehog", "lighthouse", "ice_cream"])
    # TODO: select class names from https://raw.githubusercontent.com/googlecreativelab/quickdraw-dataset/master/categories.txt
    train_ds, val_ds = dataset.split(pct=0.1)
    train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, collate_fn=dataset.collate_fn)
    val_loader = DataLoader(val_ds, batch_size=64, shuffle=False, collate_fn=dataset.collate_fn)

    # Save class names for later use in the app (read together with quickdraw_model_9layer.pth).
    # NOTE(review): the small-CNN training cell writes this same file; whichever ran last wins.
    class_names = dataset.classes
    Path(root).mkdir(exist_ok=True)
    with open(Path(root) / "class_names.txt", "w") as f:
        f.write("\n".join(class_names))

    # Instantiate model, loss, optimizer
    model = QuickDrawCNN9L(num_classes=len(class_names))
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Training loop, with an accuracy check on the held-out split after every epoch
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}")
        model.train()
        loop = tqdm(train_loader, desc="Training", leave=False)
        for batch in loop:
            images = batch['pixel_values']
            labels = batch['labels']
            preds = model(images)
            loss = criterion(preds, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loop.set_postfix(loss=loss.item())

        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for batch in val_loader:
                logits = model(batch['pixel_values'])
                correct += (logits.argmax(dim=1) == batch['labels']).sum().item()
                total += batch['labels'].numel()
        val_acc = correct / total if total else float("nan")
        print(f"Epoch {epoch + 1} done - val accuracy: {val_acc:.2%}")

    torch.save(model.state_dict(), Path(root) / "quickdraw_model_9layer.pth")
    print("Model and class names saved.")

In [None]:
# draw then predict
import tkinter as tk
from tkinter import ttk
from PIL import Image, ImageDraw, ImageOps
import torch
import numpy as np
from pathlib import Path

# Import your CNN model
from __main__ import QuickDrawCNN9L

class DrawingApp:
    """Tkinter sketch pad: draw on the canvas, then press Predict to classify with QuickDrawCNN9L.

    Reads ``class_names.txt`` and ``quickdraw_model_9layer.pth`` from ``../QuickDraw``.
    Strokes are mirrored into a 280x280 grayscale PIL image, which is what gets classified.
    """

    BRUSH_RADIUS = 5       # stroke half-width in canvas pixels
    MARGIN_FRACTION = 0.1  # empty border added around the cropped sketch before resizing

    def __init__(self, root):
        root.title("QuickDraw Classifier")

        # Set up frame
        self.mainframe = ttk.Frame(root, padding="10")
        self.mainframe.grid(row=0, column=0, sticky="nsew")

        # Create canvas for drawing
        self.canvas = tk.Canvas(self.mainframe, width=280, height=280, bg='gold', relief='solid', bd=2)
        self.canvas.grid(row=0, column=0, columnspan=3)

        # Create buttons and status label
        self.button_predict = ttk.Button(self.mainframe, text="Predict", command=self.predict)
        self.button_predict.grid(row=1, column=0, pady=10)

        self.button_clear = ttk.Button(self.mainframe, text="Clear", command=self.clear_canvas)
        self.button_clear.grid(row=1, column=1, pady=10)

        self.button_info = ttk.Button(self.mainframe, text="Info", command=self.show_info)
        self.button_info.grid(row=1, column=2, pady=10)

        self.label_status = ttk.Label(self.mainframe, text="Draw something!", anchor="w")
        self.label_status.grid(row=2, column=0, columnspan=3, sticky="w")

        # Create the image and drawing objects (the PIL image mirrors the canvas strokes)
        self.image = Image.new("L", (280, 280), color=255)  # White background
        self.draw = ImageDraw.Draw(self.image)
        self.last_point = None  # previous pointer position within the current stroke

        # Load class names
        root_dir = "../QuickDraw"
        with open(Path(root_dir) / "class_names.txt", "r") as f:
            self.class_names = [line.strip() for line in f if line.strip()]

        # Load the model
        self.model = QuickDrawCNN9L(num_classes=len(self.class_names))
        self.model.load_state_dict(torch.load(Path(root_dir) / "quickdraw_model_9layer.pth", map_location='cpu'))
        self.model.eval()

        # Bind the canvas for drawing; releasing the button ends the stroke
        self.canvas.bind("<B1-Motion>", self.on_paint)
        self.canvas.bind("<ButtonRelease-1>", self.on_release)

    def show_info(self):
        """Show available classes in a popup window."""
        info_win = tk.Toplevel()
        info_win.title("Available Classes")
        info_text = tk.Text(info_win, width=40, height=20)
        info_text.grid(padx=10, pady=10)
        info_text.insert("end", "\n".join(self.class_names))
        info_text.config(state="disabled")

    def on_paint(self, event):
        """Draw at the pointer and connect it to the previous point of the stroke.

        Motion events arrive at discrete positions. Drawing only a dot per event leaves
        gaps when the mouse moves fast, so consecutive points are joined with a line of
        the brush width on both the canvas and the PIL image.
        """
        x, y = event.x, event.y
        r = self.BRUSH_RADIUS
        if self.last_point is not None:
            x0, y0 = self.last_point
            self.canvas.create_line(x0, y0, x, y, width=2 * r, fill='black', capstyle=tk.ROUND)
            self.draw.line([x0, y0, x, y], fill=0, width=2 * r)
        self.canvas.create_oval(x-r, y-r, x+r, y+r, fill='black', outline='black')
        self.draw.ellipse([x-r, y-r, x+r, y+r], fill=0)
        self.last_point = (x, y)

    def on_release(self, event):
        """End the current stroke so the next one is not connected to it."""
        self.last_point = None

    def clear_canvas(self):
        self.canvas.delete("all")
        self.image = Image.new("L", (280, 280), color=255)
        self.draw = ImageDraw.Draw(self.image)
        self.last_point = None
        self.label_status.config(text="Draw something!")

    def _image_to_tensor(self):
        """Preprocess the drawing into a ``(1, 1, 28, 28)`` float tensor in [0, 1].

        Returns None when nothing has been drawn. The sketch is inverted (white strokes on
        black), cropped to its bounding box and centred on a square with a small border
        before resizing. QuickDraw bitmaps are rendered from drawings rescaled to
        (roughly) fill the frame, so resizing the whole canvas would feed the model a
        small, off-centre sketch unlike anything it was trained on.
        """
        inverted = ImageOps.invert(self.image)
        bbox = inverted.getbbox()
        if bbox is None:  # blank canvas
            return None
        sketch = inverted.crop(bbox)
        width, height = sketch.size
        side = max(width, height)
        border = int(side * self.MARGIN_FRACTION)
        square = Image.new("L", (side + 2 * border, side + 2 * border), color=0)
        square.paste(sketch, (border + (side - width) // 2, border + (side - height) // 2))
        resized = square.resize((28, 28), Image.LANCZOS)
        return torch.tensor(np.array(resized) / 255.0, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

    def predict(self):
        """Classify the current drawing and show the top class in the status label."""
        tensor = self._image_to_tensor()
        if tensor is None:
            self.label_status.config(text="Draw something!")
            return
        with torch.no_grad():
            idx = self.model(tensor).argmax(dim=1).item()
        self.label_status.config(text=f"Prediction: {self.class_names[idx]}")

if __name__ == "__main__":
    root = tk.Tk()
    app = DrawingApp(root)
    root.mainloop()

RuntimeError: Error(s) in loading state_dict for QuickDrawCNN:
	Missing key(s) in state_dict: "conv.0.weight", "conv.0.bias", "conv.3.weight", "conv.3.bias", "fc.0.weight", "fc.0.bias", "fc.2.weight", "fc.2.bias". 
	Unexpected key(s) in state_dict: "model.0.weight", "model.0.bias", "model.3.weight", "model.3.bias", "model.6.weight", "model.6.bias", "model.10.weight", "model.10.bias", "model.12.weight", "model.12.bias". 

In [14]:
# real-time prediction
import tkinter as tk
from tkinter import ttk
from PIL import Image, ImageDraw, ImageOps
import torch
import numpy as np
from pathlib import Path

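# Import your CNN model: QuickDrawCNN9L is defined in the 9-layer training cell above.
# Run that cell first in this kernel, or uncomment: from __main__ import QuickDrawCNN9L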

class DrawingApp:
    """Tkinter sketch pad that re-classifies the drawing with QuickDrawCNN9L every 200 ms.

    Reads ``class_names.txt`` and ``quickdraw_model_9layer.pth`` from ``../QuickDraw``.
    ``QuickDrawCNN9L`` must already be defined in the kernel (9-layer training cell).
    """

    BRUSH_RADIUS = 5           # stroke half-width in canvas pixels
    PREDICT_INTERVAL_MS = 200  # period of the live prediction loop
    MARGIN_FRACTION = 0.1      # empty border added around the cropped sketch before resizing

    def __init__(self, root):
        self.root = root
        self.root.title("QuickDraw Classifier")

        # Set up frame
        self.mainframe = ttk.Frame(root, padding="10")
        self.mainframe.grid(row=0, column=0, sticky="nsew")

        # Create canvas for drawing
        self.canvas = tk.Canvas(self.mainframe, width=280, height=280, bg='gold', relief='solid', bd=2)
        self.canvas.grid(row=0, column=0, columnspan=3)

        # Create buttons and status label
        self.button_predict = ttk.Button(self.mainframe, text="Predict", command=self.predict)
        self.button_predict.grid(row=1, column=0, pady=10)

        self.button_clear = ttk.Button(self.mainframe, text="Clear", command=self.clear_canvas)
        self.button_clear.grid(row=1, column=1, pady=10)

        self.button_info = ttk.Button(self.mainframe, text="Info", command=self.show_info)
        self.button_info.grid(row=1, column=2, pady=10)

        self.label_status = ttk.Label(self.mainframe, text="Draw something!", anchor="w")
        self.label_status.grid(row=2, column=0, columnspan=3, sticky="w")

        # Create the image and drawing objects (the PIL image mirrors the canvas strokes)
        self.image = Image.new("L", (280, 280), color=255)  # White background
        self.draw = ImageDraw.Draw(self.image)
        self.last_point = None  # previous pointer position within the current stroke

        # Load class names
        root_dir = "../QuickDraw"
        with open(Path(root_dir) / "class_names.txt", "r") as f:
            self.class_names = [line.strip() for line in f if line.strip()]

        # Load the model
        self.model = QuickDrawCNN9L(num_classes=len(self.class_names))
        self.model.load_state_dict(torch.load(Path(root_dir) / "quickdraw_model_9layer.pth", map_location='cpu'))
        self.model.eval()

        # Bind the canvas for drawing; releasing the button ends the stroke
        self.canvas.bind("<B1-Motion>", self.on_paint)
        self.canvas.bind("<ButtonRelease-1>", self.on_release)

        # Start the real-time prediction loop. _prediction_loop is the only callback that
        # reschedules itself; the Predict button runs predict() once, so clicking it
        # can no longer spawn additional concurrent loops.
        self.root.after(self.PREDICT_INTERVAL_MS, self._prediction_loop)

    def show_info(self):
        """Show available classes in a popup window."""
        info_win = tk.Toplevel()
        info_win.title("Available Classes")
        info_text = tk.Text(info_win, width=40, height=20)
        info_text.grid(padx=10, pady=10)
        info_text.insert("end", "\n".join(self.class_names))
        info_text.config(state="disabled")

    def on_paint(self, event):
        """Draw at the pointer and connect it to the previous point of the stroke.

        Joining consecutive motion events with a brush-width line avoids the dotted,
        gappy strokes produced by drawing a single dot per event.
        """
        x, y = event.x, event.y
        r = self.BRUSH_RADIUS
        if self.last_point is not None:
            x0, y0 = self.last_point
            self.canvas.create_line(x0, y0, x, y, width=2 * r, fill='black', capstyle=tk.ROUND)
            self.draw.line([x0, y0, x, y], fill=0, width=2 * r)
        self.canvas.create_oval(x-r, y-r, x+r, y+r, fill='black', outline='black')
        self.draw.ellipse([x-r, y-r, x+r, y+r], fill=0)
        self.last_point = (x, y)

    def on_release(self, event):
        """End the current stroke so the next one is not connected to it."""
        self.last_point = None

    def clear_canvas(self):
        self.canvas.delete("all")
        self.image = Image.new("L", (280, 280), color=255)
        self.draw = ImageDraw.Draw(self.image)
        self.last_point = None
        self.label_status.config(text="Draw something!")

    def _image_to_tensor(self):
        """Preprocess the drawing into a ``(1, 1, 28, 28)`` float tensor in [0, 1].

        Returns None when nothing has been drawn. The sketch is inverted (white strokes on
        black), cropped to its bounding box and centred on a square with a small border
        before resizing. QuickDraw bitmaps are rendered from drawings rescaled to
        (roughly) fill the frame, so resizing the whole canvas would not match them.
        """
        inverted = ImageOps.invert(self.image)
        bbox = inverted.getbbox()
        if bbox is None:  # blank canvas
            return None
        sketch = inverted.crop(bbox)
        width, height = sketch.size
        side = max(width, height)
        border = int(side * self.MARGIN_FRACTION)
        square = Image.new("L", (side + 2 * border, side + 2 * border), color=0)
        square.paste(sketch, (border + (side - width) // 2, border + (side - height) // 2))
        resized = square.resize((28, 28), Image.LANCZOS)
        return torch.tensor(np.array(resized) / 255.0, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

    def _prediction_loop(self):
        """Refresh the prediction, then schedule the next refresh."""
        self.predict()
        self.root.after(self.PREDICT_INTERVAL_MS, self._prediction_loop)

    def predict(self):
        """Classify the current drawing once and show the top class in the status label."""
        tensor = self._image_to_tensor()
        if tensor is None:
            self.label_status.config(text="Draw something!")
            return
        with torch.no_grad():
            idx = self.model(tensor).argmax(dim=1).item()
        self.label_status.config(text=f"Prediction: {self.class_names[idx]}")

if __name__ == "__main__":
    root = tk.Tk()
    app = DrawingApp(root)
    root.mainloop()

In [None]:
# add confidence score and second guess
import tkinter as tk
from tkinter import ttk
from PIL import Image, ImageDraw, ImageOps
import torch
import numpy as np
from pathlib import Path

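# Import your CNN model: QuickDrawCNN9L is defined in the 9-layer training cell above.
# Run that cell first in this kernel, or uncomment: from __main__ import QuickDrawCNN9L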

class DrawingApp:
    """Live QuickDraw classifier showing the top guess and the runner-up with confidences.

    The drawing is re-classified with QuickDrawCNN9L every 200 ms. Reads
    ``class_names.txt`` and ``quickdraw_model_9layer.pth`` from ``../QuickDraw``.
    ``QuickDrawCNN9L`` must already be defined in the kernel (9-layer training cell).
    """

    BRUSH_RADIUS = 5           # stroke half-width in canvas pixels
    PREDICT_INTERVAL_MS = 200  # period of the live prediction loop
    MARGIN_FRACTION = 0.1      # empty border added around the cropped sketch before resizing

    def __init__(self, root):
        self.root = root
        self.root.title("QuickDraw Classifier")

        # Set up frame
        self.mainframe = ttk.Frame(root, padding="10")
        self.mainframe.grid(row=0, column=0, sticky="nsew")

        # Create canvas for drawing
        self.canvas = tk.Canvas(self.mainframe, width=280, height=280, bg='gold', relief='solid', bd=2)
        self.canvas.grid(row=0, column=0, columnspan=3)

        # Create buttons and status label
        self.button_predict = ttk.Button(self.mainframe, text="Predict", command=self.predict)
        self.button_predict.grid(row=1, column=0, pady=10)

        self.button_clear = ttk.Button(self.mainframe, text="Clear", command=self.clear_canvas)
        self.button_clear.grid(row=1, column=1, pady=10)

        self.button_info = ttk.Button(self.mainframe, text="Info", command=self.show_info)
        self.button_info.grid(row=1, column=2, pady=10)

        self.label_status = ttk.Label(self.mainframe, text="Draw something!", anchor="w")
        self.label_status.grid(row=2, column=0, columnspan=3, sticky="w")

        # Create the image and drawing objects (the PIL image mirrors the canvas strokes)
        self.image = Image.new("L", (280, 280), color=255)  # White background
        self.draw = ImageDraw.Draw(self.image)
        self.last_point = None  # previous pointer position within the current stroke

        # Load class names
        root_dir = "../QuickDraw"
        with open(Path(root_dir) / "class_names.txt", "r") as f:
            self.class_names = [line.strip() for line in f if line.strip()]

        # Load the model
        self.model = QuickDrawCNN9L(num_classes=len(self.class_names))
        self.model.load_state_dict(torch.load(Path(root_dir) / "quickdraw_model_9layer.pth", map_location='cpu'))
        self.model.eval()

        # Bind the canvas for drawing; releasing the button ends the stroke
        self.canvas.bind("<B1-Motion>", self.on_paint)
        self.canvas.bind("<ButtonRelease-1>", self.on_release)

        # Start the real-time prediction loop. _prediction_loop is the only callback that
        # reschedules itself; the Predict button runs predict() once, so clicking it
        # can no longer spawn additional concurrent loops.
        self.root.after(self.PREDICT_INTERVAL_MS, self._prediction_loop)

    def show_info(self):
        """Show available classes in a popup window."""
        info_win = tk.Toplevel()
        info_win.title("Available Classes")
        info_text = tk.Text(info_win, width=40, height=20)
        info_text.grid(padx=10, pady=10)
        info_text.insert("end", "\n".join(self.class_names))
        info_text.config(state="disabled")

    def on_paint(self, event):
        """Draw at the pointer and connect it to the previous point of the stroke.

        Joining consecutive motion events with a brush-width line avoids the dotted,
        gappy strokes produced by drawing a single dot per event.
        """
        x, y = event.x, event.y
        r = self.BRUSH_RADIUS
        if self.last_point is not None:
            x0, y0 = self.last_point
            self.canvas.create_line(x0, y0, x, y, width=2 * r, fill='black', capstyle=tk.ROUND)
            self.draw.line([x0, y0, x, y], fill=0, width=2 * r)
        self.canvas.create_oval(x-r, y-r, x+r, y+r, fill='black', outline='black')
        self.draw.ellipse([x-r, y-r, x+r, y+r], fill=0)
        self.last_point = (x, y)

    def on_release(self, event):
        """End the current stroke so the next one is not connected to it."""
        self.last_point = None

    def clear_canvas(self):
        self.canvas.delete("all")
        self.image = Image.new("L", (280, 280), color=255)
        self.draw = ImageDraw.Draw(self.image)
        self.last_point = None
        self.label_status.config(text="Draw something!")

    def _image_to_tensor(self):
        """Preprocess the drawing into a ``(1, 1, 28, 28)`` float tensor in [0, 1].

        Returns None when nothing has been drawn. The sketch is inverted (white strokes on
        black), cropped to its bounding box and centred on a square with a small border
        before resizing. QuickDraw bitmaps are rendered from drawings rescaled to
        (roughly) fill the frame, so resizing the whole canvas would not match them.
        """
        inverted = ImageOps.invert(self.image)
        bbox = inverted.getbbox()
        if bbox is None:  # blank canvas
            return None
        sketch = inverted.crop(bbox)
        width, height = sketch.size
        side = max(width, height)
        border = int(side * self.MARGIN_FRACTION)
        square = Image.new("L", (side + 2 * border, side + 2 * border), color=0)
        square.paste(sketch, (border + (side - width) // 2, border + (side - height) // 2))
        resized = square.resize((28, 28), Image.LANCZOS)
        return torch.tensor(np.array(resized) / 255.0, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

    def _prediction_loop(self):
        """Refresh the prediction, then schedule the next refresh."""
        self.predict()
        self.root.after(self.PREDICT_INTERVAL_MS, self._prediction_loop)

    def predict(self):
        """Classify the drawing once; show the best guess and the runner-up with probabilities."""
        tensor = self._image_to_tensor()
        if tensor is None:
            self.label_status.config(text="Draw something!")
            return
        with torch.no_grad():
            probs = torch.softmax(self.model(tensor), dim=1)[0]
        # topk(k=2) would fail for a single-class model; ask for at most as many as exist.
        top = torch.topk(probs, min(2, probs.numel()))
        guesses = [(self.class_names[i], p) for i, p in zip(top.indices.tolist(), top.values.tolist())]
        lines = [f"Prediction: {guesses[0][0]} ({guesses[0][1]:.2%})"]
        if len(guesses) > 1:
            lines.append(f"Second guess: {guesses[1][0]} ({guesses[1][1]:.2%})")
        self.label_status.config(text="\n".join(lines))

if __name__ == "__main__":
    root = tk.Tk()
    app = DrawingApp(root)
    root.mainloop()

In [None]:
# need to draw specified name in time limit
import tkinter as tk
from tkinter import ttk
from PIL import Image, ImageDraw, ImageOps
import torch
import numpy as np
from pathlib import Path
import random

# Import your CNN model
from __main__ import QuickDrawCNN9L

class DrawingApp:
    """Drawing game: draw a randomly chosen QuickDraw class before the timer runs out.

    Start Round picks a target from ``class_names.txt`` and starts a 10 s countdown.
    During the round the sketch is classified with QuickDrawCNN9L every 200 ms. The
    round ends early once the top guess is the target with at least 60% probability.
    Strokes are only accepted while a round is running.
    """

    BRUSH_RADIUS = 5             # stroke half-width in canvas pixels
    ROUND_SECONDS = 10           # countdown length of one round
    CHECK_INTERVAL_MS = 200      # period of the in-round prediction loop
    RECOGNITION_THRESHOLD = 0.6  # minimum top-1 probability that counts as recognized
    MARGIN_FRACTION = 0.1        # empty border added around the cropped sketch before resizing

    def __init__(self, root):
        self.root = root
        self.root.title("QuickDraw Classifier")

        # Set up frame
        self.mainframe = ttk.Frame(root, padding="10")
        self.mainframe.grid(row=0, column=0, sticky="nsew")

        # Create canvas for drawing
        self.canvas = tk.Canvas(self.mainframe, width=280, height=280, bg='gold', relief='solid', bd=2)
        self.canvas.grid(row=0, column=0, columnspan=4)

        # Create buttons and status label
        self.button_predict = ttk.Button(self.mainframe, text="Predict", command=self.show_prediction)
        self.button_predict.grid(row=1, column=0, pady=10)

        self.button_clear = ttk.Button(self.mainframe, text="Clear", command=self.clear_canvas)
        self.button_clear.grid(row=1, column=1, pady=10)

        self.button_info = ttk.Button(self.mainframe, text="Info", command=self.show_info)
        self.button_info.grid(row=1, column=2, pady=10)

        self.button_start = ttk.Button(self.mainframe, text="Start Round", command=self.start_round)
        self.button_start.grid(row=1, column=3, pady=10)

        self.label_status = ttk.Label(self.mainframe, text="Draw something!", anchor="w")
        self.label_status.grid(row=2, column=0, columnspan=4, sticky="w")

        # Timer label
        self.label_timer = ttk.Label(self.mainframe, text="", anchor="w")
        self.label_timer.grid(row=3, column=0, columnspan=4, sticky="w")

        # Create the image and drawing objects (the PIL image mirrors the canvas strokes)
        self.image = Image.new("L", (280, 280), color=255)  # White background
        self.draw = ImageDraw.Draw(self.image)
        self.last_point = None  # previous pointer position within the current stroke

        # Load class names
        root_dir = "../QuickDraw"
        with open(Path(root_dir) / "class_names.txt", "r") as f:
            self.class_names = [line.strip() for line in f if line.strip()]

        # Load the model
        self.model = QuickDrawCNN9L(num_classes=len(self.class_names))
        self.model.load_state_dict(torch.load(Path(root_dir) / "quickdraw_model_9layer.pth", map_location='cpu'))
        self.model.eval()

        # Bind the canvas for drawing; releasing the button ends the stroke
        self.canvas.bind("<B1-Motion>", self.on_paint)
        self.canvas.bind("<ButtonRelease-1>", self.on_release)

        # Game state
        self.current_target = None
        self.time_left = 0
        self.timer_running = False
        # Pending Tk `after` ids. Starting a new round cancels them. Otherwise the previous
        # round's callbacks keep running next to the new ones (e.g. a double-speed countdown).
        self._timer_job = None
        self._check_job = None

    def show_info(self):
        """Show available classes in a popup window."""
        info_win = tk.Toplevel()
        info_win.title("Available Classes")
        info_text = tk.Text(info_win, width=40, height=20)
        info_text.grid(padx=10, pady=10)
        info_text.insert("end", "\n".join(self.class_names))
        info_text.config(state="disabled")

    def clear_canvas(self):
        self.canvas.delete("all")
        self.image = Image.new("L", (280, 280), color=255)
        self.draw = ImageDraw.Draw(self.image)
        self.last_point = None
        self.label_status.config(text="Draw something!")

    def _cancel_jobs(self):
        """Cancel any pending countdown / prediction callbacks (already-fired ids are ignored)."""
        for job in (self._timer_job, self._check_job):
            if job is not None:
                self.root.after_cancel(job)
        self._timer_job = None
        self._check_job = None

    def start_round(self):
        """Pick a random target, reset the canvas and start the countdown and prediction loop."""
        self._cancel_jobs()
        self.clear_canvas()
        self.current_target = random.choice(self.class_names)
        self.label_status.config(text=f"Draw: {self.current_target}")
        self.time_left = self.ROUND_SECONDS
        self.timer_running = True
        self.update_timer()
        self._check_job = self.root.after(self.CHECK_INTERVAL_MS, self.check_prediction)

    def update_timer(self):
        """Tick the countdown once per second; at zero, end the round and show the guesses."""
        self._timer_job = None
        if not self.timer_running:
            return
        if self.time_left > 0:
            self.label_timer.config(text=f"Time left: {self.time_left}s")
            self.time_left -= 1
            self._timer_job = self.root.after(1000, self.update_timer)
        else:
            self.label_timer.config(text="Time's up!")
            self.timer_running = False
            self.show_prediction()

    def _image_to_tensor(self):
        """Preprocess the drawing into a ``(1, 1, 28, 28)`` float tensor in [0, 1].

        Returns None when nothing has been drawn. The sketch is inverted (white strokes on
        black), cropped to its bounding box and centred on a square with a small border
        before resizing. QuickDraw bitmaps are rendered from drawings rescaled to
        (roughly) fill the frame, so resizing the whole canvas would not match them.
        """
        inverted = ImageOps.invert(self.image)
        bbox = inverted.getbbox()
        if bbox is None:  # blank canvas
            return None
        sketch = inverted.crop(bbox)
        width, height = sketch.size
        side = max(width, height)
        border = int(side * self.MARGIN_FRACTION)
        square = Image.new("L", (side + 2 * border, side + 2 * border), color=0)
        square.paste(sketch, (border + (side - width) // 2, border + (side - height) // 2))
        resized = square.resize((28, 28), Image.LANCZOS)
        return torch.tensor(np.array(resized) / 255.0, dtype=torch.float32).unsqueeze(0).unsqueeze(0)

    def _top_guesses(self):
        """Return up to two ``(class_name, probability)`` pairs, best first, or None if blank."""
        tensor = self._image_to_tensor()
        if tensor is None:
            return None
        with torch.no_grad():
            probs = torch.softmax(self.model(tensor), dim=1)[0]
        # topk(k=2) would fail for a single-class model; ask for at most as many as exist.
        top = torch.topk(probs, min(2, probs.numel()))
        return [(self.class_names[i], p) for i, p in zip(top.indices.tolist(), top.values.tolist())]

    def _status_text(self, guesses):
        """Format the target (if a round was started) and the guesses for the status label."""
        lines = []
        if self.current_target is not None:
            lines.append(f"Target: {self.current_target}")
        if guesses is None:
            lines.append("Draw something!")
        else:
            lines.append(f"Prediction: {guesses[0][0]} ({guesses[0][1]:.2%})")
            if len(guesses) > 1:
                lines.append(f"Second guess: {guesses[1][0]} ({guesses[1][1]:.2%})")
        return "\n".join(lines)

    def check_prediction(self):
        """In-round loop: update the guesses and stop once the target is recognized."""
        self._check_job = None
        if not self.timer_running:
            return
        guesses = self._top_guesses()
        self.label_status.config(text=self._status_text(guesses))
        if guesses is not None:
            best_name, best_prob = guesses[0]
            # Stop if prediction matches target with enough confidence
            if best_name == self.current_target and best_prob >= self.RECOGNITION_THRESHOLD:
                self.label_timer.config(text="I recognized it!")
                self.timer_running = False
                self._cancel_jobs()
                return
        self._check_job = self.root.after(self.CHECK_INTERVAL_MS, self.check_prediction)

    def show_prediction(self):
        """Classify the drawing once and show the target and guesses (Predict button / time-up)."""
        self.label_status.config(text=self._status_text(self._top_guesses()))

    def on_paint(self, event):
        """While a round runs, draw at the pointer and join it to the previous stroke point."""
        if not self.timer_running:
            return
        x, y = event.x, event.y
        r = self.BRUSH_RADIUS
        if self.last_point is not None:
            x0, y0 = self.last_point
            self.canvas.create_line(x0, y0, x, y, width=2 * r, fill='black', capstyle=tk.ROUND)
            self.draw.line([x0, y0, x, y], fill=0, width=2 * r)
        self.canvas.create_oval(x-r, y-r, x+r, y+r, fill='black', outline='black')
        self.draw.ellipse([x-r, y-r, x+r, y+r], fill=0)
        self.last_point = (x, y)

    def on_release(self, event):
        """End the current stroke so the next one is not connected to it."""
        self.last_point = None

if __name__ == "__main__":
    root = tk.Tk()
    app = DrawingApp(root)
    root.mainloop()