In [None]:
# Import the os module to interact with the operating system
import os

# OpenCV is configured through environment variables:
#   OPENCV_VIDEOIO_PRIORITY_MSMF = "0"  -> de-prioritise the MSMF (Media Foundation)
#       capture backend, which can be unstable on Windows, so that OpenCV falls
#       back to another backend (DirectShow is the intended one).
#   OPENCV_LOG_LEVEL = "ERROR"          -> suppress verbose OpenCV logging.
# NOTE(review): presumably these must be set before OpenCV is first imported
# in this kernel to take effect - confirm.
os.environ.update(
    {
        "OPENCV_VIDEOIO_PRIORITY_MSMF": "0",
        "OPENCV_LOG_LEVEL": "ERROR",
    }
)

In [None]:
# Import required modules for file operations and HTTP requests
import os, requests
from pathlib import Path

# Local directory that receives the downloaded .ndjson files (relative to the
# current working directory).
dst = Path("quickdraw_ndjson")
dst.mkdir(exist_ok=True)  # No error if the directory is already there

# Class names to download from the Quick, Draw! dataset.
# Each name is used verbatim (spaces, hyphens and capitalisation included) to
# build both the remote file name and the local file name in the download loop,
# so entries must match the published file names exactly.
CLASSES = [
    "aircraft carrier", "airplane", "alarm clock", "ambulance", "angel",
    "animal migration", "ant", "anvil", "apple", "arm", "asparagus", "axe",
    "backpack", "banana", "bandage", "barn", "baseball", "baseball bat",
    "basket", "basketball", "bat", "bathtub", "beach", "bear", "beard", "bed",
    "bee", "belt", "bench", "bicycle", "binoculars", "bird", "birthday cake",
    "blackberry", "blueberry", "book", "boomerang", "bottlecap", "bowtie",
    "bracelet", "brain", "bread", "bridge", "broccoli", "broom", "bucket",
    "bulldozer", "bus", "bush", "butterfly", "cactus", "cake", "calculator",
    "calendar", "camel", "camera", "camouflage", "campfire", "candle", "cannon",
    "canoe", "car", "carrot", "castle", "cat", "ceiling fan", "cello",
    "cell phone", "chair", "chandelier", "church", "circle", "clarinet", "clock",
    "cloud", "coffee cup", "compass", "computer", "cookie", "cooler", "couch",
    "cow", "crab", "crayon", "crocodile", "crown", "cruise ship", "cup",
    "diamond", "dishwasher", "diving board", "dog", "dolphin", "donut", "door",
    "dragon", "dresser", "drill", "drums", "duck", "dumbbell", "ear", "elbow",
    "elephant", "envelope", "eraser", "eye", "eyeglasses", "face", "fan",
    "feather", "fence", "finger", "fire hydrant", "fireplace", "firetruck",
    "fish", "flamingo", "flashlight", "flip flops", "floor lamp", "flower",
    "flying saucer", "foot", "fork", "frog", "frying pan", "garden", "garden hose",
    "giraffe", "goatee", "golf club", "grapes", "grass", "guitar", "hamburger",
    "hammer", "hand", "harp", "hat", "headphones", "hedgehog", "helicopter",
    "helmet", "hexagon", "hockey puck", "hockey stick", "horse", "hospital",
    "hot air balloon", "hot dog", "hot tub", "hourglass", "house", "house plant",
    "hurricane", "ice cream", "jacket", "jail", "kangaroo", "key", "keyboard",
    "knee", "knife", "ladder", "lantern", "laptop", "leaf", "leg", "light bulb",
    "lighter", "lighthouse", "lightning", "line", "lion", "lipstick", "lobster",
    "lollipop", "mailbox", "map", "marker", "matches", "megaphone", "mermaid",
    "microphone", "microwave", "monkey", "moon", "mosquito", "motorbike",
    "mountain", "mouse", "moustache", "mouth", "mug", "mushroom", "nail",
    "necklace", "nose", "ocean", "octagon", "octopus", "onion", "oven", "owl",
    "paintbrush", "paint can", "palm tree", "panda", "pants", "paper clip",
    "parachute", "parrot", "passport", "peanut", "pear", "peas", "pencil",
    "penguin", "piano", "pickup truck", "picture frame", "pig", "pillow",
    "pineapple", "pizza", "pliers", "police car", "pond", "pool", "popsicle",
    "postcard", "potato", "power outlet", "purse", "rabbit", "raccoon", "radio",
    "rain", "rainbow", "rake", "remote control", "rhinoceros", "rifle", "river",
    "roller coaster", "rollerskates", "sailboat", "sandwich", "saw", "saxophone",
    "school bus", "scissors", "scorpion", "screwdriver", "sea turtle", "see saw",
    "shark", "sheep", "shoe", "shorts", "shovel", "sink", "skateboard", "skull",
    "skyscraper", "sleeping bag", "smiley face", "snail", "snake", "snorkel",
    "snowflake", "snowman", "soccer ball", "sock", "speedboat", "spider", "spoon",
    "spreadsheet", "square", "squiggle", "squirrel", "stairs", "star", "steak",
    "stereo", "stethoscope", "stitches", "stop sign", "stove", "strawberry",
    "streetlight", "string bean", "submarine", "suitcase", "sun", "swan",
    "sweater", "swing set", "sword", "syringe", "table", "teapot", "teddy-bear",
    "telephone", "television", "tennis racquet", "tent", "The Eiffel Tower",
    "The Great Wall of China", "The Mona Lisa", "tiger", "toaster", "toe",
    "toilet", "tooth", "toothbrush", "toothpaste", "tornado", "tractor",
    "traffic light", "train", "tree", "triangle", "trombone", "truck", "trumpet",
    "t-shirt", "umbrella", "underwear", "van", "vase", "violin",
    "washing machine", "watermelon", "waterslide", "whale", "wheel", "windmill",
    "wine bottle", "wine glass", "wristwatch", "yoga", "zebra", "zigzag"
 ]          # one entry per .ndjson file; class name == file stem

# Base URL for downloading the Quick, Draw! .ndjson files
BASE = "https://storage.googleapis.com/quickdraw_dataset/full/simplified"

# Download each class's .ndjson file unless a complete copy already exists.
#
# Data is streamed into "<name>.ndjson.part" and only renamed to its final name
# once the whole body has been written. Without this, an interrupted download
# would leave a truncated "<name>.ndjson" behind, and the "already exists" check
# would silently skip it on every later run.
for c in CLASSES:
    # Class names may contain spaces, which must be percent-encoded in the URL.
    url = f"{BASE}/{c.replace(' ', '%20')}.ndjson"
    out = dst / f"{c}.ndjson"
    if out.exists():
        print(f"Skip (exists): {out}")
        continue

    print(f"Downloading {c} ...")
    tmp = out.with_name(out.name + ".part")
    try:
        # The context manager releases the streamed connection even on error.
        with requests.get(url, stream=True, timeout=60) as r:
            r.raise_for_status()  # Abort on HTTP errors (e.g. 404 for a bad class name)
            with open(tmp, "wb") as f:
                for chunk in r.iter_content(1 << 20):  # 1 MiB chunks
                    if chunk:
                        f.write(chunk)
        tmp.replace(out)  # Publish the file only after a complete download
    finally:
        # Remove the leftover partial file if anything above failed.
        if tmp.exists():
            tmp.unlink()
print("Done.")  # All downloads complete

In [None]:
# convert_qd_all_to_cls.py
# -------------------------------------------------------------
# Convert Google Quick, Draw! .ndjson stroke files into a
# YOLO CLASSIFICATION dataset (train/val/test per class).
# -------------------------------------------------------------

# Import required modules
import json, random
from pathlib import Path
from PIL import Image, ImageDraw, ImageFilter

# --------- Setup --------------------------------
# Input: directory holding the downloaded *.ndjson stroke files.
SRC_DIR = Path("quickdraw_ndjson")

# Output: root of the image dataset (train/val/test sub-folders are created below it).
OUT_DIR = Path("quickdraw_cls_all")

# Side length in pixels of the square output images. Keep in sync with YOLO imgsz.
CANVAS = 256

# Pen thickness in pixels used when drawing strokes; 6-10 is a good range.
LINE_WIDTH = 8

# Upper bound on images rendered per class, to keep the dataset tractable.
# Raise later if you have the GPU/disk for it.
MAX_PER_CLASS = 2000

# Fractions of each class assigned to (train, val, test).
SPLIT = (0.8, 0.1, 0.1)

# JPEG quality setting (tradeoff: file size vs fidelity).
JPEG_QUALITY = 90

# Gaussian blur radius applied after drawing; 0.0 disables it,
# try 0.5 for gentle anti-aliasing.
SMOOTH_BLUR = 0.0

# --------- Functions ------------------------------------------
def main():
    """
    Convert every .ndjson file in SRC_DIR into images under OUT_DIR.

    Seeds the global `random` generator with 0 (so the per-class shuffle is
    repeatable), discovers the classes, creates the output folders, converts
    each class in turn and prints a per-class and overall image count.
    """
    random.seed(0)

    class_names, ndjson_files = discover_classes(SRC_DIR)
    print(f"Discovered {len(class_names)} classes.")
    ensure_dirs(class_names)

    # Convert one class at a time, remembering how many images each produced.
    counts = []
    for name, source_file in zip(class_names, ndjson_files):
        print(f"Processing {name} ...")
        images_written = process_one_class(name, source_file)
        print(f"  wrote {images_written} images")
        counts.append(images_written)

    grand_total = sum(counts)
    print(f"\nAll done! total images = {grand_total}")
    print(f"Dataset root: {OUT_DIR.resolve()}")

def discover_classes(src_dir: Path):
    """
    Find all .ndjson files directly inside `src_dir`.

    Returns a pair of parallel lists:
      - classes: [class_name, ...]   (file names without the .ndjson suffix)
      - files:   [Path_to_ndjson, ...]
    Both are in sorted path order so runs are reproducible.

    Raises FileNotFoundError if `src_dir` contains no .ndjson file.
    """
    files = sorted(src_dir.glob("*.ndjson"))
    if not files:
        raise FileNotFoundError(f"No .ndjson found in {src_dir.resolve()}")
    classes = [p.stem for p in files]
    return classes, files

def ensure_dirs(classes):
    """
    Create the YOLO classification directory structure:
      OUT_DIR/{train,val,test}/{class}/
    Existing directories are left untouched.
    """
    for split in ("train", "val", "test"):
        for c in classes:
            (OUT_DIR / split / c).mkdir(parents=True, exist_ok=True)

def draw_example(drawing):
    """
    Rasterize a single Quick, Draw! sample to a CANVAS x CANVAS grayscale image
    (white strokes on a black background).

    - drawing: list of strokes, each stroke = [xs[], ys[]]

    Coordinates are used as pixel positions without any rescaling.
    NOTE(review): this assumes the input is the "simplified" dataset, whose
    coordinates are presumably already within the canvas - confirm if CANVAS
    is changed from 256.

    Returns a PIL.Image (mode 'L'), or None if no stroke contained any point.
    """
    img = Image.new("L", (CANVAS, CANVAS), 0)         # black background (0)
    drw = ImageDraw.Draw(img)

    drew_anything = False
    for stroke in drawing:
        xs, ys = stroke[0], stroke[1]
        # Convert parallel arrays to a list of (x, y) points
        pts = list(zip(xs, ys))
        if len(pts) > 1:
            # Polyline with a thick pen for visibility
            drw.line(pts, fill=255, width=LINE_WIDTH, joint="curve")
            drew_anything = True
        elif len(pts) == 1:
            # Single-tap stroke: draw a small filled dot instead of a line
            x, y = pts[0]
            r = max(1, LINE_WIDTH // 2)
            drw.ellipse((x - r, y - r, x + r, y + r), fill=255)
            drew_anything = True

    # Optional blur, applied only when enabled and something was drawn
    if drew_anything and SMOOTH_BLUR > 0:
        img = img.filter(ImageFilter.GaussianBlur(SMOOTH_BLUR))
    return img if drew_anything else None

def process_one_class(cls_name: str, ndjson_path: Path) -> int:
    """
    Read sketches from one .ndjson file, render up to MAX_PER_CLASS of them,
    shuffle, split into train/val/test, and save as JPEG.

    Only the first MAX_PER_CLASS drawable sketches in file order are used;
    sketches that render to nothing are skipped and do not count toward the
    cap. Blank lines in the file are ignored. The shuffle uses the global
    `random` state, so it is only repeatable if the caller seeds it.

    Files are written to OUT_DIR/<split>/<cls_name>/, which must already
    exist (see ensure_dirs). The numeric suffix restarts at 0 in each split.

    Returns the number of images written for this class (0 if none were valid).
    """
    samples = []
    with open(ndjson_path, "r", encoding="utf-8") as f:
        for line in f:
            # Tolerate blank/whitespace-only lines (e.g. a trailing newline),
            # which json.loads would otherwise reject.
            if not line.strip():
                continue
            d = json.loads(line)
            img = draw_example(d["drawing"])
            if img is None:
                continue
            samples.append(img)
            # A falsy MAX_PER_CLASS (0/None) means "no cap"
            if MAX_PER_CLASS and len(samples) >= MAX_PER_CLASS:
                break

    if not samples:
        print(f"  WARNING: no valid samples for class '{cls_name}'")
        return 0

    # Shuffle, then cut into contiguous train/val/test slices.
    # The test split takes whatever is left after the rounded-down train/val counts.
    random.shuffle(samples)
    n = len(samples)
    n_train = int(n * SPLIT[0])
    n_val   = int(n * SPLIT[1])

    splits = [
        ("train", samples[:n_train]),
        ("val",   samples[n_train:n_train + n_val]),
        ("test",  samples[n_train + n_val:]),
    ]

    # Save as JPEG; the grayscale image is converted to RGB first.
    written = 0
    for split_name, imgs in splits:
        out_dir = OUT_DIR / split_name / cls_name
        for i, im in enumerate(imgs):
            im = im.convert("RGB")
            im.save(out_dir / f"{cls_name}_{i:07d}.jpg", format="JPEG", quality=JPEG_QUALITY)
            written += 1
    return written

# Entry point. This must stay BELOW every helper definition: in a notebook
# __name__ is "__main__", so main() runs as soon as this line executes, and any
# helper defined after it would not exist yet (NameError on a fresh kernel).
if __name__ == "__main__":
    main()

In [None]:
# Import the YOLO model from the ultralytics package
from ultralytics import YOLO

# Dataset root created by the converter cell (its OUT_DIR is "quickdraw_cls_all").
# It is resolved relative to the current working directory instead of using a
# hard-coded, user-specific absolute path, so the notebook runs on any machine
# as long as it is executed from the same directory as the converter.
from pathlib import Path

DATA_DIR = Path("quickdraw_cls_all").resolve()
if not (DATA_DIR / "train").is_dir():
    raise FileNotFoundError(
        f"Expected a 'train' folder inside {DATA_DIR}; run the converter cell first."
    )

# Load a pre-trained YOLO classification model
model = YOLO("yolo11m-cls.pt")

# Train the YOLO classification model on the Quick, Draw! dataset.
# - data:     path to the dataset directory (contains train/val/test folders)
# - epochs:   number of training epochs
# - imgsz:    input image size (matches CANVAS = 256 used when creating the dataset)
# - batch:    batch size for training
# - workers:  number of data loading workers
# - lr0:      initial learning rate (lower, e.g. 8e-4, if training is unstable)
# - project:  directory to save training runs
# - name:     name for this training run
# - exist_ok: reuse/overwrite the run directory if it already exists
model.train(
    data=str(DATA_DIR),
    epochs=20,
    imgsz=256,
    batch=64,
    workers=8,
    lr0=0.001,
    project="runs/classify",
    name="qd_all_v1",
    exist_ok=True,
)

In [None]:
# Import Path from pathlib for file searching
from pathlib import Path

# List every 'best.pt' checkpoint found anywhere below the training output
# directory, one path per line.
runs_root = Path("runs/classify")
for weights_path in runs_root.rglob("best.pt"):
    print(weights_path)

runs\classify\train\weights\best.pt


In [None]:
# Import the YOLO model from ultralytics for inference and export
from ultralytics import YOLO

# Export section: convert the best trained model to ONNX format for deployment
from pathlib import Path

# Weights produced by the training run named "qd_all_v1".
WEIGHTS = Path("runs/classify/qd_all_v1/weights/best.pt")
if not WEIGHTS.is_file():
    # Fail early with a helpful message listing the checkpoints that do exist.
    candidates = "\n".join(str(p) for p in Path("runs/classify").rglob("best.pt"))
    raise FileNotFoundError(
        f"{WEIGHTS} not found. best.pt files under runs/classify:\n"
        + (candidates or "(none)")
    )

# Load the best trained model weights
best = YOLO(str(WEIGHTS))

# Export to ONNX with dynamic input size and opset 12.
# export() returns the path of the file it wrote; use that instead of
# re-deriving the location by hand.
onnx_path = best.export(format="onnx", opset=12, dynamic=True)

# Quick live demo: run inference on the webcam (labels the whole frame).
# - task="classify" is passed explicitly because the task cannot be reliably
#   guessed from an exported file named "best.onnx".
# - stream=True makes predict() return a generator, so results for an endless
#   webcam source are not accumulated in memory; the loop drives the generator.
demo = YOLO(onnx_path, task="classify")
for _ in demo.predict(
    source=0,
    show=True,
    imgsz=256,
    vid_stride=1,  # increase stride to 2-3 for higher FPS on CPU
    stream=True,
):
    pass