# Real-time Sketch Inference with Lightweight Models

This notebook demonstrates how to use SqueezeNet1_1 and MobileNetV3-Small for real-time inference on a sketch drawn on an interactive canvas.
We will:
1. Load pre-trained SqueezeNet1_1 and MobileNetV3-Small models and adapt them as feature extractors.
2. Train simple Logistic Regression classifiers on features extracted from a small subset of the QuickDraw dataset (using local binary files).
3. Provide an interactive canvas for the user to draw a sketch.
4. Perform inference on the drawn sketch and display the predicted category.

**Prerequisites:**
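- The QuickDraw `.bin` files (e.g., `full_binary_apple.bin`, `full_binary_cat.bin`) must be in the `./data` directory (or update `BINARY_DATA_ROOT`). Section 1.5 below downloads any that are missing. File names follow `full_binary_{category}.bin`, with spaces in the category name replaced by underscores (e.g., `full_binary_ice_cream.bin`).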
- Required libraries: `torch`, `torchvision`, `scikit-learn`, `Pillow`, `numpy`, `matplotlib`, `ipywidgets`, `tqdm`, `ipycanvas`, `joblib`, `requests`.
  Install them if you haven't: `pip install torch torchvision scikit-learn Pillow numpy matplotlib ipywidgets tqdm ipycanvas joblib requests`

In [None]:
!pip install -q ipywidgets ipycanvas joblib

In [None]:
# `google.colab` only exists inside Google Colab. Guard the import so the
# notebook can also be run from a regular Jupyter installation.
try:
    from google.colab import output
except ImportError:
    print("google.colab is not available; skipping Colab custom widget manager setup.")
else:
    output.enable_custom_widget_manager()

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader, ConcatDataset, Subset

from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.metrics import accuracy_score

import numpy as np
from PIL import Image, ImageDraw, UnidentifiedImageError
import os
import struct
from struct import unpack # Explicit import
import time
from tqdm.notebook import tqdm # Use tqdm.notebook for Jupyter

import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output
from ipycanvas import Canvas, hold_canvas # For interactive drawing
import io # For handling image bytes from upload
from joblib import dump, load as joblib_load # For saving/loading classifiers
import requests # Added for downloading
import urllib.parse # Added for URL encoding category names

## 1. Configuration and Device Setup

In [None]:
# Configuration
# NOTE: list order is significant. The classifier-training code enumerates this
# list and uses each category's index as its numeric class label.
QUICKDRAW_CATEGORIES = [
    'apple', 'cat', 'dog', 'door', 'elephant', 'fish',
    'flower', 'grapes', 'grass', 'house', 'ice cream', 'jail',
    'key', 'lion', 'moon', 'nose', 'pencil', 'rabbit',
    'sun', 'tree', 'umbrella', 'van', 'cake', 'airplane',
    'ant', 'banana', 'bed', 'bee', 'bicycle', 'bird',
    'book', 'bread', 'bus', 'elbow', 'ear', 'camera',
    'car', 'chair', 'clock', 'cloud', 'hand', 'computer',
    'cookie', 'cow', 'crayon', 'cup', 'eraser', 'carrot',
    'drums', 'eye', 'knife',
]

# Sample budgets are kept small so the demo trains quickly.
NUM_SAMPLES_PER_CATEGORY_FOR_CLASSIFIER = 100  # drawings per category used to fit each classifier
NUM_SAMPLES_PER_CATEGORY_FOR_BENCHMARK = 50    # drawings per category reserved for testing
BATCH_SIZE = 32
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
BINARY_DATA_ROOT = './data'  # local directory holding the full_binary_*.bin files
CANVAS_SIZE = 256            # side length of the drawing canvas
CLASSIFIER_FILENAME_TEMPLATE = "quickdraw_classifier_{model_name}.joblib"  # on-disk cache name per model

# Echo the effective settings so they are recorded in the cell output.
for _summary_line in (
    f"Using device: {DEVICE}",
    f"QuickDraw categories for classifier: {', '.join(QUICKDRAW_CATEGORIES)}",
    f"Samples per category for classifier training: {NUM_SAMPLES_PER_CATEGORY_FOR_CLASSIFIER}",
    f"Binary data root: {os.path.abspath(BINARY_DATA_ROOT)}",
):
    print(_summary_line)

# Create the data directory on first run and tell the user what belongs in it.
if not os.path.exists(BINARY_DATA_ROOT):
    os.makedirs(BINARY_DATA_ROOT, exist_ok=True)
    print(f"Warning: Data directory '{BINARY_DATA_ROOT}' was not found and has been created.")
    print(f"Please ensure QuickDraw .bin files (e.g., full_binary_apple.bin) for categories {QUICKDRAW_CATEGORIES} are placed there.")

Using device: cuda
QuickDraw categories for classifier: apple, cat, dog, door, elephant, fish, flower, grapes, grass, house, ice cream, jail, key, lion, moon, nose, pencil, rabbit, sun, tree, umbrella, van, cake, airplane, ant, banana, bed, bee, bicycle, bird, book, bread, bus, elbow, ear, camera, car, chair, clock, cloud, hand, computer, cookie, cow, crayon, cup, eraser, carrot, drums, eye, knife
Samples per category for classifier training: 100
Binary data root: /content/data
Please ensure QuickDraw .bin files (e.g., full_binary_apple.bin) for categories ['apple', 'cat', 'dog', 'door', 'elephant', 'fish', 'flower', 'grapes', 'grass', 'house', 'ice cream', 'jail', 'key', 'lion', 'moon', 'nose', 'pencil', 'rabbit', 'sun', 'tree', 'umbrella', 'van', 'cake', 'airplane', 'ant', 'banana', 'bed', 'bee', 'bicycle', 'bird', 'book', 'bread', 'bus', 'elbow', 'ear', 'camera', 'car', 'chair', 'clock', 'cloud', 'hand', 'computer', 'cookie', 'cow', 'crayon', 'cup', 'eraser', 'carrot', 'drums', 'eye

## 1.5 Download QuickDraw Binary Data
This cell will download the .bin files for the specified categories if they are not already present.


In [None]:
# %%
def download_quickdraw_binary(category_name, download_dir, timeout=30):
    """
    Download the .bin file for one QuickDraw category unless it is already on disk.

    The file is stored as 'full_binary_{category_name_underscored}.bin' inside
    `download_dir`. Data is streamed into a temporary '.part' file that is renamed
    to the final name only after the transfer completes, so an interrupted or
    failed download never leaves a truncated file that later runs would mistake
    for a finished one.

    Args:
        category_name: QuickDraw category, e.g. "apple" or "ice cream".
        download_dir: Directory that receives the file (created if missing).
        timeout: Seconds passed to `requests.get` as the connect/read timeout.

    Returns:
        None. Progress and errors are reported with `print`; request errors and
        other exceptions are caught and reported rather than raised.
    """
    # Sanitize category name for filename (replace spaces with underscores)
    filename_category_part = category_name.replace(' ', '_')
    local_filename = f"full_binary_{filename_category_part}.bin"
    local_filepath = os.path.join(download_dir, local_filename)

    if os.path.exists(local_filepath):
        print(f"File for '{category_name}' already exists: {local_filepath}")
        return

    # URL encode category name for the download URL (e.g., "ice cream" -> "ice%20cream")
    url_category_part = urllib.parse.quote(category_name)
    url = f"https://storage.googleapis.com/quickdraw_dataset/full/binary/{url_category_part}.bin"
    partial_filepath = local_filepath + ".part"

    print(f"Downloading '{category_name}' from {url} to {local_filepath}...")
    try:
        if download_dir:
            os.makedirs(download_dir, exist_ok=True)

        # The context manager releases the connection even if streaming fails midway.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()  # Raise an exception for HTTP errors

            total_size = int(response.headers.get('content-length', 0))

            with open(partial_filepath, 'wb') as f, tqdm(
                desc=category_name,
                total=total_size or None,  # unknown length -> open-ended progress bar
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for chunk in response.iter_content(chunk_size=8192):
                    size = f.write(chunk)
                    bar.update(size)

        # Publish the file under its final name only once it is complete.
        os.replace(partial_filepath, local_filepath)
        print(f"Successfully downloaded '{category_name}'.")
    except requests.exceptions.RequestException as e:
        print(f"Error downloading '{category_name}': {e}")
    except Exception as e:
        print(f"An unexpected error occurred while downloading '{category_name}': {e}")
    finally:
        # Runs on every exit path (including KeyboardInterrupt): drop any partial download.
        if os.path.exists(partial_filepath):
            os.remove(partial_filepath)


# Fetch the categories one by one, in configuration order; the helper itself
# skips any category whose file is already present.
print(
    f"Starting download process for {len(QUICKDRAW_CATEGORIES)} categories into '{BINARY_DATA_ROOT}'..."
)
for category in QUICKDRAW_CATEGORIES:
    download_quickdraw_binary(category_name=category, download_dir=BINARY_DATA_ROOT)
print("Download process finished.")


Starting download process for 51 categories into './data'...
Downloading 'apple' from https://storage.googleapis.com/quickdraw_dataset/full/binary/apple.bin to ./data/full_binary_apple.bin...


apple:   0%|          | 0.00/13.2M [00:00<?, ?iB/s]

Successfully downloaded 'apple'.
Downloading 'cat' from https://storage.googleapis.com/quickdraw_dataset/full/binary/cat.bin to ./data/full_binary_cat.bin...


cat:   0%|          | 0.00/18.7M [00:00<?, ?iB/s]

Successfully downloaded 'cat'.
Downloading 'dog' from https://storage.googleapis.com/quickdraw_dataset/full/binary/dog.bin to ./data/full_binary_dog.bin...


dog:   0%|          | 0.00/22.4M [00:00<?, ?iB/s]

Successfully downloaded 'dog'.
Downloading 'door' from https://storage.googleapis.com/quickdraw_dataset/full/binary/door.bin to ./data/full_binary_door.bin...


door:   0%|          | 0.00/8.15M [00:00<?, ?iB/s]

Successfully downloaded 'door'.
Downloading 'elephant' from https://storage.googleapis.com/quickdraw_dataset/full/binary/elephant.bin to ./data/full_binary_elephant.bin...


elephant:   0%|          | 0.00/17.9M [00:00<?, ?iB/s]

Successfully downloaded 'elephant'.
Downloading 'fish' from https://storage.googleapis.com/quickdraw_dataset/full/binary/fish.bin to ./data/full_binary_fish.bin...


fish:   0%|          | 0.00/11.6M [00:00<?, ?iB/s]

Successfully downloaded 'fish'.
Downloading 'flower' from https://storage.googleapis.com/quickdraw_dataset/full/binary/flower.bin to ./data/full_binary_flower.bin...


flower:   0%|          | 0.00/20.5M [00:00<?, ?iB/s]

Successfully downloaded 'flower'.
Downloading 'grapes' from https://storage.googleapis.com/quickdraw_dataset/full/binary/grapes.bin to ./data/full_binary_grapes.bin...


grapes:   0%|          | 0.00/31.4M [00:00<?, ?iB/s]

Successfully downloaded 'grapes'.
Downloading 'grass' from https://storage.googleapis.com/quickdraw_dataset/full/binary/grass.bin to ./data/full_binary_grass.bin...


grass:   0%|          | 0.00/11.3M [00:00<?, ?iB/s]

Successfully downloaded 'grass'.
Downloading 'house' from https://storage.googleapis.com/quickdraw_dataset/full/binary/house.bin to ./data/full_binary_house.bin...


house:   0%|          | 0.00/10.4M [00:00<?, ?iB/s]

Successfully downloaded 'house'.
Downloading 'ice cream' from https://storage.googleapis.com/quickdraw_dataset/full/binary/ice%20cream.bin to ./data/full_binary_ice_cream.bin...


ice cream:   0%|          | 0.00/11.3M [00:00<?, ?iB/s]

Successfully downloaded 'ice cream'.
Downloading 'jail' from https://storage.googleapis.com/quickdraw_dataset/full/binary/jail.bin to ./data/full_binary_jail.bin...


jail:   0%|          | 0.00/11.0M [00:00<?, ?iB/s]

Successfully downloaded 'jail'.
Downloading 'key' from https://storage.googleapis.com/quickdraw_dataset/full/binary/key.bin to ./data/full_binary_key.bin...


key:   0%|          | 0.00/16.1M [00:00<?, ?iB/s]

Successfully downloaded 'key'.
Downloading 'lion' from https://storage.googleapis.com/quickdraw_dataset/full/binary/lion.bin to ./data/full_binary_lion.bin...


lion:   0%|          | 0.00/23.5M [00:00<?, ?iB/s]

Successfully downloaded 'lion'.
Downloading 'moon' from https://storage.googleapis.com/quickdraw_dataset/full/binary/moon.bin to ./data/full_binary_moon.bin...


moon:   0%|          | 0.00/11.1M [00:00<?, ?iB/s]

Successfully downloaded 'moon'.
Downloading 'nose' from https://storage.googleapis.com/quickdraw_dataset/full/binary/nose.bin to ./data/full_binary_nose.bin...


nose:   0%|          | 0.00/13.2M [00:00<?, ?iB/s]

Successfully downloaded 'nose'.
Downloading 'pencil' from https://storage.googleapis.com/quickdraw_dataset/full/binary/pencil.bin to ./data/full_binary_pencil.bin...


pencil:   0%|          | 0.00/9.62M [00:00<?, ?iB/s]

Successfully downloaded 'pencil'.
Downloading 'rabbit' from https://storage.googleapis.com/quickdraw_dataset/full/binary/rabbit.bin to ./data/full_binary_rabbit.bin...


rabbit:   0%|          | 0.00/24.1M [00:00<?, ?iB/s]

Successfully downloaded 'rabbit'.
Downloading 'sun' from https://storage.googleapis.com/quickdraw_dataset/full/binary/sun.bin to ./data/full_binary_sun.bin...


sun:   0%|          | 0.00/15.1M [00:00<?, ?iB/s]

Successfully downloaded 'sun'.
Downloading 'tree' from https://storage.googleapis.com/quickdraw_dataset/full/binary/tree.bin to ./data/full_binary_tree.bin...


tree:   0%|          | 0.00/19.6M [00:00<?, ?iB/s]

Successfully downloaded 'tree'.
Downloading 'umbrella' from https://storage.googleapis.com/quickdraw_dataset/full/binary/umbrella.bin to ./data/full_binary_umbrella.bin...


umbrella:   0%|          | 0.00/11.0M [00:00<?, ?iB/s]

Successfully downloaded 'umbrella'.
Downloading 'van' from https://storage.googleapis.com/quickdraw_dataset/full/binary/van.bin to ./data/full_binary_van.bin...


van:   0%|          | 0.00/20.7M [00:00<?, ?iB/s]

Successfully downloaded 'van'.
Downloading 'cake' from https://storage.googleapis.com/quickdraw_dataset/full/binary/cake.bin to ./data/full_binary_cake.bin...


cake:   0%|          | 0.00/17.0M [00:00<?, ?iB/s]

Successfully downloaded 'cake'.
Downloading 'airplane' from https://storage.googleapis.com/quickdraw_dataset/full/binary/airplane.bin to ./data/full_binary_airplane.bin...


airplane:   0%|          | 0.00/15.0M [00:00<?, ?iB/s]

Successfully downloaded 'airplane'.
Downloading 'ant' from https://storage.googleapis.com/quickdraw_dataset/full/binary/ant.bin to ./data/full_binary_ant.bin...


ant:   0%|          | 0.00/17.7M [00:00<?, ?iB/s]

Successfully downloaded 'ant'.
Downloading 'banana' from https://storage.googleapis.com/quickdraw_dataset/full/binary/banana.bin to ./data/full_binary_banana.bin...


banana:   0%|          | 0.00/23.9M [00:00<?, ?iB/s]

Successfully downloaded 'banana'.
Downloading 'bed' from https://storage.googleapis.com/quickdraw_dataset/full/binary/bed.bin to ./data/full_binary_bed.bin...


bed:   0%|          | 0.00/10.5M [00:00<?, ?iB/s]

Successfully downloaded 'bed'.
Downloading 'bee' from https://storage.googleapis.com/quickdraw_dataset/full/binary/bee.bin to ./data/full_binary_bee.bin...


bee:   0%|          | 0.00/19.6M [00:00<?, ?iB/s]

Successfully downloaded 'bee'.
Downloading 'bicycle' from https://storage.googleapis.com/quickdraw_dataset/full/binary/bicycle.bin to ./data/full_binary_bicycle.bin...


bicycle:   0%|          | 0.00/17.6M [00:00<?, ?iB/s]

Successfully downloaded 'bicycle'.
Downloading 'bird' from https://storage.googleapis.com/quickdraw_dataset/full/binary/bird.bin to ./data/full_binary_bird.bin...


bird:   0%|          | 0.00/16.4M [00:00<?, ?iB/s]

Successfully downloaded 'bird'.
Downloading 'book' from https://storage.googleapis.com/quickdraw_dataset/full/binary/book.bin to ./data/full_binary_book.bin...


book:   0%|          | 0.00/13.5M [00:00<?, ?iB/s]

Successfully downloaded 'book'.
Downloading 'bread' from https://storage.googleapis.com/quickdraw_dataset/full/binary/bread.bin to ./data/full_binary_bread.bin...


bread:   0%|          | 0.00/9.28M [00:00<?, ?iB/s]

Successfully downloaded 'bread'.
Downloading 'bus' from https://storage.googleapis.com/quickdraw_dataset/full/binary/bus.bin to ./data/full_binary_bus.bin...


bus:   0%|          | 0.00/24.1M [00:00<?, ?iB/s]

Successfully downloaded 'bus'.
Downloading 'elbow' from https://storage.googleapis.com/quickdraw_dataset/full/binary/elbow.bin to ./data/full_binary_elbow.bin...


elbow:   0%|          | 0.00/9.12M [00:00<?, ?iB/s]

Successfully downloaded 'elbow'.
Downloading 'ear' from https://storage.googleapis.com/quickdraw_dataset/full/binary/ear.bin to ./data/full_binary_ear.bin...


ear:   0%|          | 0.00/9.70M [00:00<?, ?iB/s]

Successfully downloaded 'ear'.
Downloading 'camera' from https://storage.googleapis.com/quickdraw_dataset/full/binary/camera.bin to ./data/full_binary_camera.bin...


camera:   0%|          | 0.00/13.0M [00:00<?, ?iB/s]

Successfully downloaded 'camera'.
Downloading 'car' from https://storage.googleapis.com/quickdraw_dataset/full/binary/car.bin to ./data/full_binary_car.bin...


car:   0%|          | 0.00/23.3M [00:00<?, ?iB/s]

Successfully downloaded 'car'.
Downloading 'chair' from https://storage.googleapis.com/quickdraw_dataset/full/binary/chair.bin to ./data/full_binary_chair.bin...


chair:   0%|          | 0.00/16.8M [00:00<?, ?iB/s]

Successfully downloaded 'chair'.
Downloading 'clock' from https://storage.googleapis.com/quickdraw_dataset/full/binary/clock.bin to ./data/full_binary_clock.bin...


clock:   0%|          | 0.00/12.3M [00:00<?, ?iB/s]

Successfully downloaded 'clock'.
Downloading 'cloud' from https://storage.googleapis.com/quickdraw_dataset/full/binary/cloud.bin to ./data/full_binary_cloud.bin...


cloud:   0%|          | 0.00/12.7M [00:00<?, ?iB/s]

Successfully downloaded 'cloud'.
Downloading 'hand' from https://storage.googleapis.com/quickdraw_dataset/full/binary/hand.bin to ./data/full_binary_hand.bin...


hand:   0%|          | 0.00/29.7M [00:00<?, ?iB/s]

Successfully downloaded 'hand'.
Downloading 'computer' from https://storage.googleapis.com/quickdraw_dataset/full/binary/computer.bin to ./data/full_binary_computer.bin...


computer:   0%|          | 0.00/13.6M [00:00<?, ?iB/s]

Successfully downloaded 'computer'.
Downloading 'cookie' from https://storage.googleapis.com/quickdraw_dataset/full/binary/cookie.bin to ./data/full_binary_cookie.bin...


cookie:   0%|          | 0.00/19.7M [00:00<?, ?iB/s]

Successfully downloaded 'cookie'.
Downloading 'cow' from https://storage.googleapis.com/quickdraw_dataset/full/binary/cow.bin to ./data/full_binary_cow.bin...


cow:   0%|          | 0.00/24.4M [00:00<?, ?iB/s]

Successfully downloaded 'cow'.
Downloading 'crayon' from https://storage.googleapis.com/quickdraw_dataset/full/binary/crayon.bin to ./data/full_binary_crayon.bin...


crayon:   0%|          | 0.00/10.8M [00:00<?, ?iB/s]

Successfully downloaded 'crayon'.
Downloading 'cup' from https://storage.googleapis.com/quickdraw_dataset/full/binary/cup.bin to ./data/full_binary_cup.bin...


cup:   0%|          | 0.00/12.1M [00:00<?, ?iB/s]

Successfully downloaded 'cup'.
Downloading 'eraser' from https://storage.googleapis.com/quickdraw_dataset/full/binary/eraser.bin to ./data/full_binary_eraser.bin...


eraser:   0%|          | 0.00/10.4M [00:00<?, ?iB/s]

Successfully downloaded 'eraser'.
Downloading 'carrot' from https://storage.googleapis.com/quickdraw_dataset/full/binary/carrot.bin to ./data/full_binary_carrot.bin...


carrot:   0%|          | 0.00/13.5M [00:00<?, ?iB/s]

Successfully downloaded 'carrot'.
Downloading 'drums' from https://storage.googleapis.com/quickdraw_dataset/full/binary/drums.bin to ./data/full_binary_drums.bin...


drums:   0%|          | 0.00/18.8M [00:00<?, ?iB/s]

Successfully downloaded 'drums'.
Downloading 'eye' from https://storage.googleapis.com/quickdraw_dataset/full/binary/eye.bin to ./data/full_binary_eye.bin...


eye:   0%|          | 0.00/16.1M [00:00<?, ?iB/s]

Successfully downloaded 'eye'.
Downloading 'knife' from https://storage.googleapis.com/quickdraw_dataset/full/binary/knife.bin to ./data/full_binary_knife.bin...


knife:   0%|          | 0.00/12.0M [00:00<?, ?iB/s]

Successfully downloaded 'knife'.
Download process finished.


## 2. Model Definitions (Feature Extractors)
These are adapted from the `test.py` script.

In [None]:
class SqueezeNetFeatureExtractor(nn.Module):
    """Feature extractor built from the `features` stage of a given model.

    The wrapped `features` output is reduced with a 1x1 adaptive average pool
    and flattened, producing one flat feature vector per input sample.
    """

    def __init__(self, original_model):
        """Reuse `original_model.features`; the pooling and flatten layers are new."""
        super().__init__()
        self.features = original_model.features
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten(1)

    def forward(self, x):
        """Run features -> average pool -> flatten and return the flattened result."""
        pooled = self.pool(self.features(x))
        return self.flatten(pooled)

class MobileNetV3FeatureExtractor(nn.Module):
    """Feature extractor built from the `features` and `avgpool` stages of a given model.

    Both stages are taken from the wrapped model unchanged; their output is
    flattened so each input sample yields one flat feature vector.
    """

    def __init__(self, original_model):
        """Reuse `original_model.features` and `original_model.avgpool`; add a flatten layer."""
        super().__init__()
        self.features = original_model.features
        self.avgpool = original_model.avgpool
        self.flatten = nn.Flatten(1)

    def forward(self, x):
        """Run features -> avgpool -> flatten and return the flattened result."""
        pooled = self.avgpool(self.features(x))
        return self.flatten(pooled)

def _unset_runtime_slots():
    """Return the per-model entries that start as None and are filled in during model preparation."""
    return {
        "classifier": None,
        "transform": None,
        "feature_extractor_instance": None,
    }


# Registry of the models used for inference, keyed by display name. Each entry
# holds the static recipe (pre-trained weights, constructor, feature-extractor
# wrapper) followed by the runtime slots that are populated later.
MODELS_FOR_INFERENCE = {
    "SqueezeNet1_1": {
        "weights": models.SqueezeNet1_1_Weights.IMAGENET1K_V1,
        "model_fn": models.squeezenet1_1,
        "feature_extractor_fn": lambda m: SqueezeNetFeatureExtractor(m),
        **_unset_runtime_slots(),
    },
    "MobileNetV3-Small": {
        "weights": models.MobileNet_V3_Small_Weights.IMAGENET1K_V1,
        "model_fn": models.mobilenet_v3_small,
        "feature_extractor_fn": lambda m: MobileNetV3FeatureExtractor(m),
        **_unset_runtime_slots(),
    },
}

## 3. QuickDraw Data Loading Utilities
Adapted from the `test.py` script; these utilities load the data used to train the classifiers.

In [None]:
# --- QuickDraw Binary Data Reading Functions (adapted from test.py) ---
def unpack_drawing(file_handle):
    """Decode the next drawing record from a QuickDraw binary stream.

    Record layout (little-endian, no padding): 8-byte key id, 2-byte country
    code, 1-byte recognized flag, 4-byte timestamp, 2-byte stroke count; then,
    for every stroke, a 2-byte point count followed by that many x bytes and
    that many y bytes.

    Every field is read with an exact-length read. The function does not use
    `peek()`, which on buffered files may return fewer bytes than requested at
    a buffer boundary and would make valid records look truncated.

    Args:
        file_handle: Binary file-like object positioned at the start of a
            record. Only `read(n)` is required.

    Returns:
        A dict with keys 'key_id', 'country_code', 'recognized', 'timestamp'
        and 'image' (a list of `(x_coords, y_coords)` tuples, one per stroke),
        or None when the stream is exhausted or the record is truncated.
    """
    def _read_exact(n_bytes):
        # A short read means end of file (or a truncated record).
        chunk = file_handle.read(n_bytes)
        if len(chunk) != n_bytes:
            raise EOFError(f"expected {n_bytes} bytes, got {len(chunk)}")
        return chunk

    try:
        # Fixed-size header: 8 + 2 + 1 + 4 + 2 = 17 bytes.
        key_id, country_code, recognized, timestamp, n_strokes = unpack('<Q2sbIH', _read_exact(17))
        image_strokes = []
        for _ in range(n_strokes):
            n_points, = unpack('<H', _read_exact(2))
            fmt = f'<{n_points}B'
            # An empty stroke (n_points == 0) naturally decodes to ((), ()).
            x = unpack(fmt, _read_exact(n_points))
            y = unpack(fmt, _read_exact(n_points))
            image_strokes.append((x, y))
    except (EOFError, struct.error):  # End of file or malformed data
        return None

    return {
        'key_id': key_id,
        'country_code': country_code,
        'recognized': recognized,
        'timestamp': timestamp,
        'image': image_strokes # This is drawing_strokes
    }


def unpack_drawings(filename):
    """Yield the recognized drawings stored in a QuickDraw binary file.

    Records are decoded one after another with `unpack_drawing`; iteration ends
    at the first record that comes back as None. Drawings whose 'recognized'
    field is falsy are skipped. A missing file yields nothing.
    """
    if os.path.exists(filename):
        with open(filename, 'rb') as stream:
            # iter(callable, sentinel) keeps calling until the decoder returns None.
            for record in iter(lambda: unpack_drawing(stream), None):
                if record.get('recognized', False):
                    yield record


# --- Custom QuickDraw Dataset from Local Binary Files (adapted from test.py) ---
class QuickDrawBinaryDataset(Dataset):
    """Dataset of recognized QuickDraw drawings for one category, rendered to images.

    The selected stroke records are read into memory when the dataset is
    constructed; each `__getitem__` call rasterizes one record into a grayscale
    PIL image (black strokes on white) and applies the optional transform.
    """
    IMAGE_SIZE = (256, 256)  # (width, height) of the rendered image
    LINE_WIDTH = 2           # stroke width in pixels

    def __init__(self, root_dir, category_name, transform=None, max_items=None, category_label=0, offset=0): # Added offset
        """Load up to `max_items` recognized drawings, skipping the first `offset`.

        Args:
            root_dir: Directory containing `full_binary_{category}.bin`.
            category_name: Category name; spaces are replaced by underscores
                to form the file name.
            transform: Optional callable applied to each rendered PIL image.
            max_items: Maximum number of drawings to keep. None means no
                limit; zero or a negative value loads nothing.
            category_label: Numeric label returned with every item.
            offset: Number of recognized drawings to skip before loading.
        """
        self.root_dir = root_dir
        self.category_name = category_name.replace(' ', '_') # Ensure filename-safe
        self.transform = transform
        self.filepath = os.path.join(self.root_dir, f"full_binary_{self.category_name}.bin")

        self.drawings_data = []
        self.category_label = category_label
        self.offset = offset # Store offset
        self.max_items = max_items # Store max_items

        # The limit is checked after each append, so a non-positive limit must be
        # handled up front; otherwise one drawing would slip through.
        if self.max_items is None or self.max_items > 0:
            drawings = unpack_drawings(self.filepath)
            try:
                for position, drawing_data_dict in enumerate(drawings):
                    if position < self.offset:
                        continue
                    self.drawings_data.append(drawing_data_dict)
                    if self.max_items is not None and len(self.drawings_data) >= self.max_items:
                        break
            finally:
                # Stopping early leaves the source generator suspended with its file
                # still open; close it now instead of waiting for garbage collection.
                close_source = getattr(drawings, "close", None)
                if close_source is not None:
                    close_source()

        if not self.drawings_data and os.path.exists(self.filepath):
            print(f"Warning: No recognized drawings loaded for category {self.category_name} from {self.filepath} (offset: {self.offset}, max_items: {self.max_items}).")
        elif not os.path.exists(self.filepath):
            print(f"Warning: File not found for category {self.category_name}: {self.filepath}")


    def _render_drawing_to_image(self, drawing_strokes): # drawing_strokes is item['image']
        """Rasterize a list of `(x_coords, y_coords)` strokes into a grayscale PIL image."""
        image = Image.new("L", self.IMAGE_SIZE, "white")
        draw = ImageDraw.Draw(image)
        max_x = self.IMAGE_SIZE[0] - 1
        max_y = self.IMAGE_SIZE[1] - 1
        for stroke_x, stroke_y in drawing_strokes:
            if not stroke_x or not stroke_y: # Skip empty strokes
                continue
            # Clamp every coordinate into the image bounds.
            points = [
                (min(max(px, 0), max_x), min(max(py, 0), max_y))
                for px, py in zip(stroke_x, stroke_y)
            ]
            if len(points) == 1:
                # A lone point has no line to draw; use a small filled circle so it stays visible.
                x_coord, y_coord = points[0]
                radius = self.LINE_WIDTH // 2 if self.LINE_WIDTH > 1 else 1
                draw.ellipse([(x_coord-radius, y_coord-radius),
                              (x_coord+radius, y_coord+radius)], fill="black")
            else:
                draw.line(points, fill="black", width=self.LINE_WIDTH)
        return image

    def __len__(self):
        """Number of drawings loaded at construction time."""
        return len(self.drawings_data)

    def __getitem__(self, idx):
        """Return `(image, label)` for drawing `idx`; the image is transformed if a transform is set."""
        drawing_data_dict = self.drawings_data[idx]
        # 'image' key in drawing_data_dict holds the list of strokes
        pil_image = self._render_drawing_to_image(drawing_data_dict['image'])

        label = self.category_label # Use the stored numeric label

        if self.transform:
            pil_image = self.transform(pil_image)

        return pil_image, label

# --- Feature Extraction Function ---
def extract_features_for_classifier(feature_extractor_instance, dataloader, device, description="Extracting features"):
    """Run a feature extractor over a dataloader and collect the results as NumPy arrays.

    Args:
        feature_extractor_instance: Module called on each input batch; it is put
            in eval mode and moved to `device` first.
        dataloader: Iterable of `(inputs, labels)` batches that supports `len()`.
        device: Device the extractor and the inputs are moved to.
        description: Text used for the progress bar and the empty-loader warning.

    Returns:
        `(features, labels)` concatenated along axis 0, or a pair of empty
        arrays when the dataloader produces no batches.
    """
    feature_extractor_instance.eval()
    feature_extractor_instance.to(device)

    if len(dataloader) == 0:
        print(f"Warning: Dataloader for '{description}' is empty.")
        return np.array([]), np.array([])

    batch_features = []
    batch_labels = []
    for inputs, labels in tqdm(dataloader, desc=description, leave=False):
        with torch.no_grad():
            outputs = feature_extractor_instance(inputs.to(device))
        batch_features.append(outputs.cpu().detach().numpy())
        # Labels may arrive as a tensor or as a plain sequence.
        if isinstance(labels, torch.Tensor):
            batch_labels.append(labels.cpu().detach().numpy())
        else:
            batch_labels.append(np.array(labels))

    if len(batch_features) == 0:
        return np.array([]), np.array([])

    return np.concatenate(batch_features, axis=0), np.concatenate(batch_labels, axis=0)

## 4. Prepare Models and Train Classifiers
This function will load data, extract features, and train a Logistic Regression classifier for each model.
It will also prepare the feature extractors and transforms for later inference.

In [None]:
def prepare_models_and_train_classifiers(force_retrain: bool = False) -> None:
    """Populate every entry of MODELS_FOR_INFERENCE so it is ready for inference.

    For each registered model this function:
      1. builds the pre-trained base model and stores its feature extractor
         (on DEVICE, in eval mode) under "feature_extractor_instance";
      2. stores the preprocessing pipeline (grayscale -> 3 channels, then the
         transforms supplied by the weights) under "transform";
      3. loads a cached classifier from disk when one exists and
         `force_retrain` is False; otherwise extracts features from a
         QuickDraw subset, fits a StandardScaler + LogisticRegression
         pipeline, saves it with joblib and reports training-set accuracy.

    Args:
        force_retrain: When True, ignore any classifier saved on disk and fit a new one.

    Returns:
        None. Results are written into MODELS_FOR_INFERENCE; progress is printed.
        A model whose classifier could not be trained keeps "classifier" set to None.
    """
    global MODELS_FOR_INFERENCE

    any_classifier_trained_this_session = False
    for model_name, config in MODELS_FOR_INFERENCE.items():
        print(f"\n--- Preparing model and classifier for: {model_name} ---")

        # Cache file name; hyphens in the model name are replaced by underscores.
        classifier_path = CLASSIFIER_FILENAME_TEMPLATE.format(model_name=model_name.replace('-', '_'))

        # 1. Load pre-trained model and create feature extractor instance
        print("Loading pre-trained model...")
        weights = config["weights"]
        base_model = config["model_fn"](weights=weights)
        # Store the instantiated feature extractor
        MODELS_FOR_INFERENCE[model_name]["feature_extractor_instance"] = config["feature_extractor_fn"](base_model).to(DEVICE).eval()

        # 2. Get model-specific transforms
        # The rendered sketches are single-channel; Grayscale(3) replicates them to
        # three channels before the transforms that ship with the weights are applied.
        base_model_transform = weights.transforms()
        quickdraw_transform = T.Compose([
            T.Grayscale(num_output_channels=3),
            base_model_transform
        ])
        MODELS_FOR_INFERENCE[model_name]["transform"] = quickdraw_transform

        # 3. Attempt to load classifier or train if needed/forced
        # NOTE(review): joblib files are pickles - only load files created by this notebook.
        # NOTE(review): a cached classifier is not checked against the current
        # QUICKDRAW_CATEGORIES; use force_retrain=True after editing the category list.
        if not force_retrain and os.path.exists(classifier_path):
            try:
                MODELS_FOR_INFERENCE[model_name]["classifier"] = joblib_load(classifier_path)
                print(f"Loaded trained classifier for {model_name} from {classifier_path}")
                continue # Move to next model
            except Exception as e:
                # A cache file that cannot be loaded falls through to retraining below.
                print(f"Could not load classifier for {model_name} from {classifier_path}: {e}. Retraining.")

        print(f"Training new classifier for {model_name}...")
        # 4. Load QuickDraw data (small subset for training classifier)
        print("Loading QuickDraw data for classifier training...")
        all_datasets_for_training = []
        for i, category_name_str in enumerate(QUICKDRAW_CATEGORIES): # i will be the label
            dataset = QuickDrawBinaryDataset(
                root_dir=BINARY_DATA_ROOT,
                category_name=category_name_str,
                transform=quickdraw_transform, # Apply transform at dataset level
                max_items=NUM_SAMPLES_PER_CATEGORY_FOR_CLASSIFIER,
                category_label=i
            )
            # Categories that produced no drawings are left out; the others keep their index as label.
            if len(dataset) > 0:
                 all_datasets_for_training.append(dataset)

        if not all_datasets_for_training:
            print(f"No data loaded for any category for model {model_name}. Cannot train classifier.")
            MODELS_FOR_INFERENCE[model_name]["classifier"] = None # Ensure it's None
            continue

        classifier_train_dataset = ConcatDataset(all_datasets_for_training)
        if len(classifier_train_dataset) == 0:
            print(f"Combined dataset is empty for {model_name}. Cannot train classifier.")
            MODELS_FOR_INFERENCE[model_name]["classifier"] = None # Ensure it's None
            continue

        print(f"Total samples for {model_name} classifier training: {len(classifier_train_dataset)}")
        # At most two loader workers; zero when the CPU count cannot be determined.
        num_workers = min(os.cpu_count(), 2) if os.cpu_count() is not None else 0
        train_loader = DataLoader(classifier_train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers)

        # 5. Extract features using the stored feature_extractor_instance
        print("Extracting features for classifier training...")
        current_feature_extractor = MODELS_FOR_INFERENCE[model_name]["feature_extractor_instance"]
        train_features, train_labels = extract_features_for_classifier(current_feature_extractor, train_loader, DEVICE,
                                                        description=f"Extracting features ({model_name})")

        if train_features.size == 0:
            print(f"No features extracted for {model_name}. Cannot train classifier.")
            MODELS_FOR_INFERENCE[model_name]["classifier"] = None # Ensure it's None
            continue

        # 6. Train Logistic Regression classifier
        print("Training Logistic Regression classifier...")
        classifier = make_pipeline(StandardScaler(), LogisticRegression(max_iter=2000, random_state=42, C=0.1, solver='liblinear')) # Increased max_iter
        try:
            classifier.fit(train_features, train_labels)
            MODELS_FOR_INFERENCE[model_name]["classifier"] = classifier
            any_classifier_trained_this_session = True

            # Save the trained classifier
            dump(classifier, classifier_path)
            print(f"Saved trained classifier for {model_name} to {classifier_path}")

            # This accuracy is measured on the same features the classifier was fitted
            # on, so it is a sanity check and not an estimate of held-out performance.
            predictions = classifier.predict(train_features)
            accuracy = accuracy_score(train_labels, predictions)
            print(f"Classifier training set accuracy for {model_name}: {accuracy:.4f}")
        except Exception as e:
            # Any failure in this block (fit, save or evaluation) resets the classifier to None.
            print(f"Error fitting classifier for {model_name}: {e}")
            MODELS_FOR_INFERENCE[model_name]["classifier"] = None


    if any_classifier_trained_this_session:
        print("\n--- Classifier training/loading process completed. ---")
    else:
        print("\n--- All classifiers loaded from disk or no new classifiers were trained. ---")

    # Final check if any model is ready
    # A model counts as ready only when classifier, feature extractor and transform are all set.
    ready_models = [name for name, conf in MODELS_FOR_INFERENCE.items() if conf["classifier"] and conf["feature_extractor_instance"] and conf["transform"]]
    if not ready_models:
        print("\nWARNING: No models are ready for inference. Check data paths and training process.")
    else:
        print(f"\nModels ready for inference: {', '.join(ready_models)}")

## 5. Run Model Preparation and Classifier Training
This will load pre-trained models, set up feature extractors, and either load existing classifiers or train new ones.
Set `force_retrain=True` to retrain all classifiers even if saved versions exist.

In [None]:
# Leave False to reuse classifiers already saved on disk; set to True to ignore
# the saved .joblib files and fit fresh classifiers.
FORCE_RETRAIN_CLASSIFIERS = False
prepare_models_and_train_classifiers(force_retrain=FORCE_RETRAIN_CLASSIFIERS)


--- Preparing model and classifier for: SqueezeNet1_1 ---
Loading pre-trained model...
Training new classifier for SqueezeNet1_1...
Loading QuickDraw data for classifier training...
Total samples for SqueezeNet1_1 classifier training: 2282
Extracting features for classifier training...


Extracting features (SqueezeNet1_1):   0%|          | 0/72 [00:00<?, ?it/s]

Training Logistic Regression classifier...
Saved trained classifier for SqueezeNet1_1 to quickdraw_classifier_SqueezeNet1_1.joblib
Classifier training set accuracy for SqueezeNet1_1: 0.9706

--- Preparing model and classifier for: MobileNetV3-Small ---
Loading pre-trained model...
Training new classifier for MobileNetV3-Small...
Loading QuickDraw data for classifier training...
Total samples for MobileNetV3-Small classifier training: 2282
Extracting features for classifier training...


Extracting features (MobileNetV3-Small):   0%|          | 0/72 [00:00<?, ?it/s]

Training Logistic Regression classifier...
Saved trained classifier for MobileNetV3-Small to quickdraw_classifier_MobileNetV3_Small.joblib
Classifier training set accuracy for MobileNetV3-Small: 0.9904

--- Classifier training/loading process completed. ---

Models ready for inference: SqueezeNet1_1, MobileNetV3-Small


## 6. Image Preprocessing and Inference Logic for Sketches

In [None]:
# --- Output area for image and predictions ---
# One Output widget shared by all inference runs: perform_inference_on_pil_image
# enters it with `with inference_output_area:` and clears it before printing.
inference_output_area = widgets.Output()

def _preprocess_pil_sketch_for_inference(pil_image_L, model_specific_transform, display_image=True):
    """
    Turn a PIL sketch (canvas capture or uploaded file) into a model input.

    Steps, in order:
      1. Convert to mode 'L' (grayscale) when the image is in another mode.
      2. Resize to QuickDrawBinaryDataset.IMAGE_SIZE with LANCZOS resampling.
      3. Invert the pixels when the image is dark on average (mean < 128), so a
         white-on-black sketch becomes black-on-white.
      4. Optionally show the processed sketch with matplotlib.
      5. Apply `model_specific_transform` and return its result.

    Any exception raised along the way is printed and None is returned instead.
    """
    try:
        grayscale = pil_image_L if pil_image_L.mode == 'L' else pil_image_L.convert('L')

        # Standardize the size up front; the model transform does the final sizing.
        sketch = grayscale.resize(QuickDrawBinaryDataset.IMAGE_SIZE, Image.Resampling.LANCZOS)

        # Heuristic: a dark average means light strokes on a dark background.
        mostly_dark = np.array(sketch).mean() < 128
        if mostly_dark:
            sketch = Image.eval(sketch, lambda px: 255 - px)

        if display_image:
            plt.figure(figsize=(3,3)) # Small figure so it fits the notebook output
            plt.imshow(sketch, cmap='gray')
            plt.title("Processed Sketch for Inference")
            plt.axis('off')
            plt.show()

        model_input = model_specific_transform(sketch)
    except Exception as e:
        print(f"Error in _preprocess_pil_sketch_for_inference: {e}")
        return None
    return model_input


def perform_inference_on_pil_image(pil_image_L, source_description="Sketch"):
    """
    Run every prepared model on a sketch and print the predictions.

    Args:
        pil_image_L: PIL image of the sketch; it is handed to
            `_preprocess_pil_sketch_for_inference`, which converts it to
            grayscale if needed.
        source_description: Label printed at the top of the output to say where
            the sketch came from.

    All output is written into `inference_output_area`, which is cleared first.
    For each entry of `MODELS_FOR_INFERENCE` that has a classifier, a feature
    extractor and a transform, the predicted category, its confidence and the
    per-category probabilities are printed. Entries missing a component, or
    whose preprocessing fails, are skipped with a message. Returns None.
    """
    with inference_output_area:
        clear_output(wait=True) # Clear previous inference results
        print(f"Performing inference on: {source_description}\n")

        any_model_processed = False
        displayed_once = False # To display the processed image only once per call

        for model_name, config in MODELS_FOR_INFERENCE.items():
            print(f"--- {model_name} ---")
            classifier = config.get("classifier")
            feature_extractor = config.get("feature_extractor_instance") # Use the instantiated one
            transform = config.get("transform")

            if not classifier:
                print(f"Classifier for {model_name} not available. Skipping.")
                continue
            if not feature_extractor:
                print(f"Feature extractor for {model_name} not available. Skipping.")
                continue
            if not transform:
                print(f"Transform for {model_name} not available. Skipping.")
                continue

            any_model_processed = True

            # Process the PIL image using the model's specific transform
            input_tensor = _preprocess_pil_sketch_for_inference(pil_image_L, transform, display_image=not displayed_once)
            if input_tensor is None:
                print(f"Skipping {model_name} due to image processing error.")
                continue
            displayed_once = True # Image shown, don't show again for other models in this call

            # Add the batch dimension expected by the feature extractor.
            input_tensor = input_tensor.unsqueeze(0).to(DEVICE)

            feature_extractor.eval() # Ensure it's in eval mode
            with torch.no_grad():
                features_np = feature_extractor(input_tensor).cpu().numpy()

            try:
                if hasattr(classifier, "predict_proba") and hasattr(classifier, "predict"):
                    predicted_label = int(classifier.predict(features_np)[0])
                    predicted_proba = classifier.predict_proba(features_np)[0]

                    # predict_proba columns are ordered by classifier.classes_, not by
                    # category index. They differ whenever a category contributed no
                    # training samples, so map each column back to its label explicitly.
                    class_labels = getattr(classifier, "classes_", None)
                    if class_labels is None:
                        class_labels = range(len(predicted_proba))
                    proba_by_label = {
                        int(label): float(proba)
                        for label, proba in zip(class_labels, predicted_proba)
                    }

                    predicted_category = QUICKDRAW_CATEGORIES[predicted_label]
                    confidence = proba_by_label[predicted_label]

                    print(f"Predicted Category: {predicted_category}")
                    print(f"Confidence: {confidence:.4f}")
                    print("Probabilities per category:")
                    for i, cat_name in enumerate(QUICKDRAW_CATEGORIES):
                        # Categories the classifier never saw get probability 0.
                        print(f"  {cat_name}: {proba_by_label.get(i, 0.0):.4f}")
                else:
                    print("Classifier does not support predict_proba or predict.")
            except Exception as e:
                print(f"Error during prediction/probability calculation for {model_name}: {e}")
            print("-" * 20)

        if not any_model_processed:
            print("No models were ready for inference. Please ensure classifiers are trained/loaded and models prepared.")

## 7. Interactive Drawing Canvas for Real-Time Inference
Draw your sketch in the canvas below (black lines on a white background). Then click "Predict from Canvas".

In [None]:
# --- Interactive Canvas Setup ---
# Square drawing surface of CANVAS_SIZE pixels with a visible border.
# sync_image_data=True is set because the predict handler reads the pixels back
# through canvas.get_image_data().
canvas = Canvas(width=CANVAS_SIZE, height=CANVAS_SIZE, layout=widgets.Layout(border='1px solid black'), sync_image_data=True)

# Reset the canvas to a blank white sheet and (re)apply the pen settings
def initialize_drawing_canvas():
    """Fill the whole canvas with white and configure a black, round-tipped pen."""
    pen_width = QuickDrawBinaryDataset.LINE_WIDTH * 1.5 # Slightly thicker than the dataset strokes
    with hold_canvas(canvas): # Send all drawing commands as a single batch
        # Background
        canvas.fill_style = 'white'
        canvas.fill_rect(0, 0, CANVAS_SIZE, CANVAS_SIZE)
        # Pen
        canvas.line_join = 'round' # Smooth corners between segments
        canvas.line_cap = 'round' # Smooth stroke ends
        canvas.line_width = pen_width
        canvas.stroke_style = 'black'

initialize_drawing_canvas() # Set up the canvas initially

# Module-level flag shared by the mouse handlers: True while a stroke is in progress.
is_drawing_on_canvas = False

def on_canvas_mouse_down(x, y):
    """Begin a stroke: set the drawing flag and open a new path at (x, y)."""
    global is_drawing_on_canvas
    is_drawing_on_canvas = True
    with hold_canvas(canvas):
        canvas.begin_path() # Start a new line segment
        canvas.move_to(x,y)

def on_canvas_mouse_move(x, y):
    """Extend the current path to (x, y) and stroke it, but only while drawing."""
    if not is_drawing_on_canvas:
        return # Pointer is just hovering; nothing to draw
    with hold_canvas(canvas):
        canvas.line_to(x, y)
        canvas.stroke() # Render the path up to the new point

def on_canvas_mouse_up(x,y):
    """Finish the stroke at (x, y) if one is in progress, then clear the drawing flag."""
    global is_drawing_on_canvas
    if not is_drawing_on_canvas:
        return # No stroke in progress; the flag is already off
    with hold_canvas(canvas):
        canvas.line_to(x, y)
        canvas.stroke() # Draw the closing segment
    is_drawing_on_canvas = False

def on_canvas_mouse_out(x,y): # If mouse leaves canvas while drawing
    """Abort the current stroke by clearing the drawing flag; nothing is drawn here."""
    global is_drawing_on_canvas
    is_drawing_on_canvas = False # Stop drawing

# Register the stroke handlers defined above with the canvas mouse events.
canvas.on_mouse_down(on_canvas_mouse_down)
canvas.on_mouse_move(on_canvas_mouse_move)
canvas.on_mouse_up(on_canvas_mouse_up)
canvas.on_mouse_out(on_canvas_mouse_out)


# --- Buttons for Canvas ---
# Their click callbacks are attached further down, after the callbacks are defined.
predict_canvas_button = widgets.Button(description="Predict from Canvas")
clear_canvas_button = widgets.Button(description="Clear Canvas")

def predict_from_canvas_action(button_event):
    """
    Click callback: read the canvas pixels, convert them to a grayscale PIL
    image and run inference on it.

    `button_event` is the argument supplied by the button and is not used.
    Errors are reported, together with a traceback, inside
    `inference_output_area` rather than being raised.
    """
    try:
        # Canvas pixels as a numpy array; the layout is expected to be (height, width, 4).
        rgba_pixels = np.asarray(canvas.get_image_data())

        has_pixels = rgba_pixels is not None and rgba_pixels.size != 0
        if not has_pixels:
            with inference_output_area:
                clear_output(wait=True)
                print("Canvas is empty or image data could not be retrieved.")
            return

        grayscale_sketch = Image.fromarray(rgba_pixels, 'RGBA').convert('L')
        perform_inference_on_pil_image(grayscale_sketch, source_description="Canvas Drawing")
    except Exception as e:
        import traceback
        with inference_output_area:
            clear_output(wait=True)
            print(f"Error getting or processing image from canvas: {e}")
            print("Please ensure you have drawn something on the canvas.")
            traceback.print_exc()


def clear_canvas_action(button_event):
    """Click callback: reset the canvas to blank and clear the prediction output.

    `button_event` is the argument supplied by the button and is not used.
    """
    initialize_drawing_canvas() # Re-initialize to clear and reset styles
    with inference_output_area: # Clear previous predictions as well
        clear_output()


# Attach the click callbacks to the two buttons.
predict_canvas_button.on_click(predict_from_canvas_action)
clear_canvas_button.on_click(clear_canvas_action)

# Display canvas and buttons
# Layout order: canvas, then the button row, then the shared output area.
canvas_controls = widgets.HBox([predict_canvas_button, clear_canvas_button])
print("Draw your sketch below and click 'Predict from Canvas'.")
display(canvas)
display(canvas_controls)
display(inference_output_area) # Shared output area for predictions

Draw your sketch below and click 'Predict from Canvas'.


Canvas(height=256, layout=Layout(border='1px solid black'), sync_image_data=True, width=256)

HBox(children=(Button(description='Predict from Canvas', style=ButtonStyle()), Button(description='Clear Canva…

Output()

## 7.5 Benchmarking Classifier Performance
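This section evaluates the trained classifiers on a held-out test set read from the same local QuickDraw `.bin` files.
For each category, the loader is offset past the samples used for classifier training, so the benchmark samples were not seen during training.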


In [None]:
def create_benchmark_dataloader(categories, root_dir, transform, num_samples_per_category, offset_per_category, batch_size, num_workers=0):
    """
    Build a non-shuffled DataLoader over benchmark samples from every category.

    One QuickDrawBinaryDataset is created per category, labelled with the
    category's position in `categories`, limited to `num_samples_per_category`
    items and created with `offset=offset_per_category` so that it reads
    different data from the classifier-training set. Categories that yield no
    samples are reported and left out.

    Returns:
        A DataLoader over the concatenated per-category datasets, or None when
        no samples could be loaded at all.
    """
    print(f"\nLoading benchmark data (offsetting by {offset_per_category} samples per category)...")

    per_category_sets = []
    for label, category in enumerate(categories):
        category_set = QuickDrawBinaryDataset(
            root_dir=root_dir,
            category_name=category,
            transform=transform,
            max_items=num_samples_per_category,
            category_label=label,
            offset=offset_per_category # The offset is what selects different data
        )
        if len(category_set) == 0:
            print(f"Warning: No benchmark data loaded for category '{category}' with offset.")
            continue
        per_category_sets.append(category_set)

    if len(per_category_sets) == 0:
        print("No benchmark datasets could be created. Aborting benchmark.")
        return None

    merged_set = ConcatDataset(per_category_sets)
    total_samples = len(merged_set)
    print(f"Total samples for benchmarking: {total_samples}")

    if total_samples == 0:
        print("Combined benchmark dataset is empty. Aborting benchmark.")
        return None

    return DataLoader(merged_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)

def run_benchmark():
    """
    Evaluate every prepared classifier on held-out QuickDraw samples.

    For each entry of `MODELS_FOR_INFERENCE` that has a classifier, a feature
    extractor and a transform, this builds a benchmark DataLoader (offset by
    `NUM_SAMPLES_PER_CATEGORY_FOR_CLASSIFIER` so the data differs from the
    training data), extracts features batch by batch, predicts with the
    classifier and prints the accuracy plus a per-class classification report.

    Reads the module-level configuration (`QUICKDRAW_CATEGORIES`,
    `BINARY_DATA_ROOT`, `NUM_SAMPLES_PER_CATEGORY_FOR_BENCHMARK`, `BATCH_SIZE`,
    `DEVICE`) without modifying it. Results are printed only; returns None.
    """
    # Imported locally because the notebook's import cell only pulls in
    # accuracy_score from sklearn.metrics; without this the report step raised
    # NameError, which was then swallowed by the except below.
    from sklearn.metrics import classification_report

    print("\n--- Starting Benchmark ---")

    any_model_benchmarked = False
    for model_name, config in MODELS_FOR_INFERENCE.items():
        print(f"\n--- Benchmarking: {model_name} ---")

        classifier = config.get("classifier")
        feature_extractor = config.get("feature_extractor_instance")
        transform = config.get("transform")

        if not classifier or not feature_extractor or not transform:
            print(f"Model {model_name} or its components not ready. Skipping benchmark.")
            continue

        # Create benchmark DataLoader using the model's transform
        # The offset ensures we use data not seen by the classifier training
        benchmark_loader = create_benchmark_dataloader(
            categories=QUICKDRAW_CATEGORIES,
            root_dir=BINARY_DATA_ROOT,
            transform=transform,
            num_samples_per_category=NUM_SAMPLES_PER_CATEGORY_FOR_BENCHMARK,
            offset_per_category=NUM_SAMPLES_PER_CATEGORY_FOR_CLASSIFIER,
            batch_size=BATCH_SIZE
        )

        if benchmark_loader is None:
            print(f"Could not create benchmark loader for {model_name}. Skipping.")
            continue

        if len(benchmark_loader.dataset) == 0:
            print(f"Benchmark dataset for {model_name} is empty. Skipping.")
            continue

        all_preds = []
        all_labels = []

        feature_extractor.eval()
        feature_extractor.to(DEVICE)

        print(f"Extracting features and predicting for {model_name} on benchmark set...")
        for inputs, labels in tqdm(benchmark_loader, desc=f"Benchmarking {model_name}", leave=False):
            inputs = inputs.to(DEVICE)
            with torch.no_grad():
                features_np = feature_extractor(inputs).cpu().numpy()

            try:
                preds = classifier.predict(features_np)
            except Exception as e:
                # A failed batch is skipped entirely so predictions and labels stay aligned.
                print(f"Error during prediction for benchmark batch with {model_name}: {e}")
                continue

            # Store plain Python ints so labels can be used as list indices and set members.
            all_preds.extend(np.asarray(preds).tolist())
            all_labels.extend(labels.cpu().numpy().tolist())

        if not all_labels or not all_preds:
            print(f"No predictions or labels collected for {model_name}. Skipping metrics.")
            continue

        accuracy = accuracy_score(all_labels, all_preds)
        print(f"\nBenchmark Accuracy for {model_name}: {accuracy:.4f}")

        try:
            # Report on every label that occurs in either the ground truth or the
            # predictions, and pass that list explicitly so `target_names` always
            # lines up with the rows of the report.
            report_labels = sorted(set(all_labels) | set(all_preds))
            target_names_for_report = [QUICKDRAW_CATEGORIES[i] for i in report_labels]

            print("\nClassification Report:")
            # zero_division=0 covers classes with no true or no predicted samples in a small test set
            report = classification_report(
                all_labels,
                all_preds,
                labels=report_labels,
                target_names=target_names_for_report,
                zero_division=0,
            )
            print(report)
        except Exception as e:
            print(f"Could not generate classification report for {model_name}: {e}")

        any_model_benchmarked = True
        print("-" * 30)

    if not any_model_benchmarked:
        print("No models were benchmarked. Check configurations and data.")
    print("--- Benchmark Finished ---")

# Run the benchmark only once `prepare_models_and_train_classifiers()` has filled in
# every model entry: each one needs a classifier, a feature extractor instance and
# a transform.
models_are_ready = all(
    all(model_config.get(part) for part in ("classifier", "feature_extractor_instance", "transform"))
    for model_config in MODELS_FOR_INFERENCE.values()
)

if not models_are_ready:
    print("Models are not fully prepared. Skipping benchmark. Please run 'prepare_models_and_train_classifiers' first.")
else:
    run_benchmark()



--- Starting Benchmark ---

--- Benchmarking: SqueezeNet1_1 ---

Loading benchmark data (offsetting by 100 samples per category)...
Total samples for benchmarking: 108
Extracting features and predicting for SqueezeNet1_1 on benchmark set...


Benchmarking SqueezeNet1_1:   0%|          | 0/4 [00:00<?, ?it/s]


Benchmark Accuracy for SqueezeNet1_1: 0.7407

Classification Report:
Could not generate classification report for SqueezeNet1_1: name 'classification_report' is not defined
This might happen if some classes had no test samples or no predictions.
------------------------------

--- Benchmarking: MobileNetV3-Small ---

Loading benchmark data (offsetting by 100 samples per category)...
Total samples for benchmarking: 108
Extracting features and predicting for MobileNetV3-Small on benchmark set...


Benchmarking MobileNetV3-Small:   0%|          | 0/4 [00:00<?, ?it/s]


Benchmark Accuracy for MobileNetV3-Small: 0.8148

Classification Report:
Could not generate classification report for MobileNetV3-Small: name 'classification_report' is not defined
This might happen if some classes had no test samples or no predictions.
------------------------------
--- Benchmark Finished ---


## 8. (Optional) Inference on Uploaded Sketch
You can also upload a pre-drawn sketch (PNG, JPG, GIF - black lines on white background preferred).

In [None]:
# --- Image Upload Widget ---
# Single-file upload restricted to common image extensions.
uploader = widgets.FileUpload(
    accept='.png,.jpg,.jpeg,.gif',
    multiple=False,
    description='Upload Sketch File'
)

# Name of the most recently handled upload; maintained by handle_file_upload.
_last_uploaded_file_name = None # To help with re-upload logic if needed

def handle_file_upload(change):
    """
    Observer for `uploader.value`: run inference on the uploaded sketch.

    Args:
        change: Change notification passed by the widget; the current value is
            read from `uploader` directly, so it is not inspected.

    Supports both layouts of `FileUpload.value`:
      * ipywidgets 7.x: dict mapping file name -> {'metadata': ..., 'content': bytes}
      * ipywidgets 8.x: tuple of dicts with 'name' and 'content' (a memoryview)

    Errors while opening or processing the image are printed inside
    `inference_output_area`. Afterwards the widget is reset so the same file
    can be uploaded again.
    """
    global _last_uploaded_file_name
    # Imported here because `io` is not among the notebook's top-level imports.
    import io

    uploaded = uploader.value
    if not uploaded: # No file selected or selection cleared
        _last_uploaded_file_name = None
        return

    if isinstance(uploaded, dict):
        # ipywidgets 7.x: the file name is the dict key; take the latest entry.
        file_name, file_info = list(uploaded.items())[-1]
    else:
        # ipywidgets 8.x: first (and only, since multiple=False) file record.
        file_info = uploaded[0]
        file_name = file_info['name']
    file_content = bytes(file_info['content']) # Normalizes bytes / memoryview
    _last_uploaded_file_name = file_name

    try:
        pil_image = Image.open(io.BytesIO(file_content)) # Open image
        if pil_image.mode != 'L':
            pil_image_L = pil_image.convert("L") # Convert to Grayscale if not already
        else:
            pil_image_L = pil_image
        perform_inference_on_pil_image(pil_image_L, source_description=f"Uploaded Sketch: {file_name}")
    except UnidentifiedImageError:
        with inference_output_area:
            clear_output(wait=True)
            print(f"Error: Cannot identify image file '{file_name}'. Please upload a valid image (PNG, JPG, GIF).")
    except Exception as e:
        with inference_output_area:
            clear_output(wait=True)
            print(f"Error opening or processing uploaded image '{file_name}': {e}")

    # Reset the uploader to allow re-uploading the same file.
    if isinstance(uploader.value, dict):
        # ipywidgets 7.x: empty the dict in place and reset the file counter.
        uploader.value.clear()
        uploader._counter = 0
    else:
        # ipywidgets 8.x: the value is an immutable tuple, so assign an empty one.
        # This re-triggers the observer, which returns early on the empty value.
        uploader.value = ()

# Call handle_file_upload whenever the uploader's `value` trait changes.
uploader.observe(handle_file_upload, names='value')

print("\nOr, upload a sketch file:")
display(uploader)
# The inference_output_area is already displayed with the canvas section.


Or, upload a sketch file:


FileUpload(value={}, accept='.png,.jpg,.jpeg,.gif', description='Upload Sketch File')