## Setup

In [23]:
from tqdm.notebook import tqdm # Import tqdm
import requests # Added for downloading
import urllib.parse # Added for URL encoding category names
import torch
import torchvision.models as models
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
import time
import struct # For unpacking binary data
from struct import unpack
import os
import collections
import logging
from datetime import datetime
import csv
import json
import gzip
import pickle
import random
from torch import nn
import torch.nn.functional as F

# import urllib.request # Removed: No longer downloading
from PIL import Image, ImageDraw
try:
    import git
    GIT_AVAILABLE = True
except ImportError:
    GIT_AVAILABLE = False
    print("gitpython not installed. Using default versioning.")


### Google Authentication for Multimodal Embeddings

In [2]:
from google.colab import auth
from google.cloud import aiplatform
# Corrected import: Image and MultiModalEmbeddingModel are directly under google.cloud.aiplatform.types
from vertexai.vision_models import Image as AiPlatformImage, MultiModalEmbeddingModel

# --- ⚠️ ACTION REQUIRED ⚠️ ---
# Authenticate your Google Cloud account.
# This will open a pop-up window for you to log in.
auth.authenticate_user()

# --- ⚠️ ACTION REQUIRED ⚠️ ---
# Please enter your Google Cloud Project ID and the region.
PROJECT_ID = "gen-lang-client-0897829271"  # @param {type:"string"}
LOCATION = "asia-south1"        # @param {type:"string"}

# Initialize the Vertex AI SDK
aiplatform.init(project=PROJECT_ID, location=LOCATION)

print("✅ Setup and authentication complete.")

✅ Setup and authentication complete.


In [3]:
# Remove stale data
!rm -rf results sample_ensemble_predictions data models
!mkdir results data models

In [18]:
# Configuration
QUICKDRAW_CATEGORIES = [
    'apple', 'cat', 'dog',
    'door', 'elephant', 'fish', 'flower', 'table',
    'grass', 'house', 'ice cream', 'circle', 'key', 'lion', 'moon', 'nose',
    'pencil', 'rabbit', 'sun', 'tree', 'umbrella', 'van', 'cake', 'airplane',
    'ant', 'banana', 'bed', 'bee', 'bicycle', 'bird', 'book', 'bread', 'bus',
    'elbow', 'ear', 'camera', 'car', 'chair', 'clock', 'cloud', 'hand',
    'computer', 'cookie', 'cow', 'crayon', 'cup', 'eraser', 'carrot', 'drums',
    'eye', 'knife'
]

QUICKDRAW_CATEGORIES_DETAILED = [
    {"name": "apple", "category": "Fruit", "id": "001"},
    {"name": "cat", "category": "Animal", "id": "002"},
    {"name": "dog", "category": "Animal", "id": "003"},
    {"name": "door", "category": "Building Part", "id": "004"},
    {"name": "elephant", "category": "Animal", "id": "005"},
    {"name": "fish", "category": "Animal", "id": "006"},
    {"name": "flower", "category": "Plant", "id": "007"},
    {"name": "table", "category": "Furniture", "id": "008"},
    {"name": "grass", "category": "Plant", "id": "009"},
    {"name": "house", "category": "Building", "id": "010"},
    {"name": "ice cream", "category": "Food", "id": "011"},
    {"name": "circle", "category": "Shape", "id": "012"},
    {"name": "key", "category": "Tool", "id": "013"},
    {"name": "lion", "category": "Animal", "id": "014"},
    {"name": "moon", "category": "Celestial Body", "id": "015"},
    {"name": "nose", "category": "Body Part", "id": "016"},
    {"name": "pencil", "category": "Writing Utensil", "id": "017"},
    {"name": "rabbit", "category": "Animal", "id": "018"},
    {"name": "sun", "category": "Celestial Body", "id": "019"},
    {"name": "tree", "category": "Plant", "id": "020"},
    {"name": "umbrella", "category": "Accessory", "id": "021"},
    {"name": "van", "category": "Vehicle", "id": "022"},
    {"name": "cake", "category": "Food", "id": "023"},
    {"name": "airplane", "category": "Vehicle", "id": "024"},
    {"name": "ant", "category": "Animal", "id": "025"},
    {"name": "banana", "category": "Fruit", "id": "026"},
    {"name": "bed", "category": "Furniture", "id": "027"},
    {"name": "bee", "category": "Animal", "id": "028"},
    {"name": "bicycle", "category": "Vehicle", "id": "029"},
    {"name": "bird", "category": "Animal", "id": "030"},
    {"name": "book", "category": "Object", "id": "031"},
    {"name": "bread", "category": "Food", "id": "032"},
    {"name": "bus", "category": "Vehicle", "id": "033"},
    {"name": "elbow", "category": "Body Part", "id": "034"},
    {"name": "ear", "category": "Body Part", "id": "035"},
    {"name": "camera", "category": "Device", "id": "036"},
    {"name": "car", "category": "Vehicle", "id": "037"},
    {"name": "chair", "category": "Furniture", "id": "038"},
    {"name": "clock", "category": "Object", "id": "039"},
    {"name": "cloud", "category": "Weather", "id": "040"},
    {"name": "hand", "category": "Body Part", "id": "041"},
    {"name": "computer", "category": "Device", "id": "042"},
    {"name": "cookie", "category": "Food", "id": "043"},
    {"name": "cow", "category": "Animal", "id": "044"},
    {"name": "crayon", "category": "Writing Utensil", "id": "045"},
    {"name": "cup", "category": "Object", "id": "046"},
    {"name": "eraser", "category": "Writing Utensil", "id": "047"},
    {"name": "carrot", "category": "Vegetable", "id": "048"},
    {"name": "drums", "category": "Musical Instrument", "id": "049"},
    {"name": "eye", "category": "Body Part", "id": "050"},
    {"name": "knife", "category": "Tool", "id": "051"}
]

print(f"📚 Defined {len(QUICKDRAW_CATEGORIES_DETAILED)} objects to recognize.")

# NUM_TRAIN_SAMPLES_PER_CATEGORY = 8000
NUM_TRAIN_SAMPLES_PER_CATEGORY = 100
# NUM_TEST_SAMPLES_PER_CATEGORY = 2000
NUM_TEST_SAMPLES_PER_CATEGORY = 20
QUICKDRAW_CACHE_SIZE=50000  # Increased from 20000 to reduce cache misses with larger datasets
IMAGE_SIZE = (224, 224)
LINE_WIDTH = 2
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BINARY_DATA_ROOT = './data' # MODIFIED: Point to local data directory

# Configuration section - add dynamic batch size scaling
BATCH_SIZE = 128  # Increase base batch size for smaller datasets
# Add adaptive batch size based on dataset size
MAX_BATCH_SIZE = 128  # Increased for ml.g4dn.xlarge (16GB GPU)
MIN_BATCH_SIZE = 16   # Reduced from 32 to prevent OOM errors with large datasets

# Add a function to calculate appropriate batch size
def get_adaptive_batch_size(num_samples_per_category, num_categories):
    """Calculate appropriate batch size based on dataset size to prevent OOM errors"""
    total_samples = num_samples_per_category * num_categories

    if total_samples < 10000:  # Small dataset
        return MAX_BATCH_SIZE
    elif total_samples < 50000:  # Medium dataset
        return 64  # Increased from 32
    elif total_samples < 200000:  # Large dataset
        return 32  # Increased from 16
    else:  # Very large dataset (10000 samples x 50+ categories)
        return MIN_BATCH_SIZE

# Fine-tuning hyperparameters - updated for better training
NUM_FINETUNE_EPOCHS = 20               # Set to 20 as specified
FINETUNE_LEARNING_RATE = 5e-4          # Slightly increased from 1e-4
FINETUNE_WEIGHT_DECAY = 1e-5           # Added weight decay for regularization
MODEL_SAVE_PATH = './models'           # Directory to save fine-tuned models
VALIDATION_SPLIT = 0.1                 # Percentage of training data to use for validation
USE_GRADUAL_UNFREEZING = True          # Whether to use gradual unfreezing
USE_DATA_AUGMENTATION = True           # Whether to use data augmentation
GRADIENT_ACCUMULATION_STEPS = 1        # Default: update weights after every batch
USE_GRADIENT_CHECKPOINTING = True      # Enable gradient checkpointing to save memory
CHECKPOINT_INTERVAL = 5                # Save checkpoints every N epochs
RESUME_FROM_CHECKPOINT = True          # Whether to resume from checkpoint if available


# --- Model Definitions and Feature Extractors (Unchanged) ---
MODELS_TO_TEST = {
    "MobileNetV3-Small": {
        "weights": models.MobileNet_V3_Small_Weights.IMAGENET1K_V1,
        "model_fn": models.mobilenet_v3_small,
    },
    "ShuffleNetV2_x0_5": {
        "weights": models.ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1,
        "model_fn": models.shufflenet_v2_x0_5,
    },
    "SqueezeNet1_1": {
        "weights": models.SqueezeNet1_1_Weights.IMAGENET1K_V1,
        "model_fn": models.squeezenet1_1,
    },
    # "EfficientNet-B0": {
    #     "weights": models.EfficientNet_B0_Weights.IMAGENET1K_V1,
    #     "model_fn": models.efficientnet_b0,
    #
    # }
}

# --- Logging Configuration ---
# Set to logging.DEBUG for verbose development output, logging.INFO for less
LOG_LEVEL = logging.INFO

# Create a logs directory if it doesn't exist
LOGS_DIR = './logs'
if not os.path.exists(LOGS_DIR):
    os.makedirs(LOGS_DIR, exist_ok=True)

# Generate a log file name based on the notebook file name, datetime, and environment
notebook_name = 'quickdraw_benchmark'
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
log_file_name = f"{notebook_name}_{timestamp}_{device_type}.log"
log_file_path = os.path.join(LOGS_DIR, log_file_name)

# Get the root logger (or a specific logger, __name__ is fine)
# Avoid using basicConfig if you need fine-grained handler control after creation
# basicConfig configures the root logger, but we can get it and clear handlers
# or just create our own logger from scratch. Let's create our own logger explicitly.
logger = logging.getLogger(__name__)
logger.setLevel(LOG_LEVEL) # Set the level for *this* logger

# Prevent duplicate handlers if the cell is run multiple times
if logger.hasHandlers():
    logger.handlers.clear()

# Create handlers manually
# File Handler: Use buffering=1 for line buffering (most common for text) or 0 for no buffering
# For binary data like the .bin files are processed from, default buffering applies.
# However, the FileHandler *itself* writes text logs, so we can try line buffering.
# If that's not sufficient, we could force flush periodically.
try:
    # Using buffering=1 for line buffering in text mode ('w') is standard,
    # but FileHandler uses 'a' by default. Let's try 'a' with a smaller buffer if possible,
    # or just force flushing. Manual flushing is more reliable for immediate write.
    file_handler = logging.FileHandler(log_file_path, mode='w') # Use 'w' to overwrite each run or 'a' to append
    file_handler.setLevel(LOG_LEVEL)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(module)s - %(message)s')
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    # Stream Handler for console output
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(LOG_LEVEL)
    stream_handler.setFormatter(formatter) # Use the same formatter
    logger.addHandler(stream_handler)

    logger.info(f"Logging initialized. Logs will be saved to {log_file_path}")

except Exception as e:
    # Fallback: logger.info an error if logging setup fails
    logger.info(f"Error setting up logging handlers: {e}", flush=True)
    # Revert to basicConfig just for console output if file logging failed
    logging.basicConfig(level=LOG_LEVEL, format='%(asctime)s - %(levelname)s - %(module)s - %(message)s')
    logger = logging.getLogger(__name__) # Get the root logger now configured by basicConfig
    logger.warning("File logging setup failed, falling back to console-only logging.")
    logger.info(f"Logging initialized (console only). Failed to save to {log_file_path}")




# Create a folder to download dataset if it does not exists
if not os.path.exists(BINARY_DATA_ROOT):
    os.makedirs(BINARY_DATA_ROOT, exist_ok=True)
    logger.warning(f"Warning: Data directory '{BINARY_DATA_ROOT}' was not found and has been created.")
    logger.info(f"Please ensure QuickDraw .bin files (e.g., full_binary_apple.bin) for categories {QUICKDRAW_CATEGORIES} are placed there.")

# Create directory for saving models if it doesn't exist
if not os.path.exists(MODEL_SAVE_PATH):
    os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
    logger.info(f"Created directory for saving models: {MODEL_SAVE_PATH}")

def clear_gpu_memory():
    """Clear GPU cache to free memory"""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def get_git_info():
    """Get the current git commit hash and message"""
    if not GIT_AVAILABLE:
        return {"available": False, "commit": "no_git", "message": "no_git"}

    try:
        repo = git.Repo(search_parent_directories=True)
        commit_hash = repo.head.commit.hexsha[:8]  # Short hash
        commit_message = repo.head.commit.message.strip().split('\n')[0]  # First line only
        # Replace spaces and special chars for filename safety
        safe_message = commit_message.replace(' ', '_').replace('/', '-').replace(':', '-')[:30]
        return {
            "available": True,
            "commit": commit_hash,
            "message": safe_message
        }
    except (git.InvalidGitRepositoryError, git.NoSuchPathError):
        return {"available": False, "commit": "no_git", "message": "no_git"}

# Create parameter string for filename
params_str = f"samples{NUM_TRAIN_SAMPLES_PER_CATEGORY}_epochs{NUM_FINETUNE_EPOCHS}_classes{len(QUICKDRAW_CATEGORIES)}"
git_info = get_git_info()
if git_info["available"]:
    git_str = f"_{git_info['message'][:10]}..."
else:
    git_str = ""

# Save latest epoch checkpoint (always overwriting previous)
LATEST_PATH = f'_{git_str}_{params_str}'


# Create embeddings dir and generate filename for the multimodal embeddings model
def get_embedding_filename(num_categories, num_images_per_category, base_dir="./embeddings"):
    """
    Generates a filename for the image embeddings file based on the number of categories
    and images per category. Ensures the directory exists.
    """
    # Create base directory if it doesn't exist
    os.makedirs(base_dir, exist_ok=True)
    filename = f"quickdraw_image_embeddings_categories{num_categories}_samples{num_images_per_category}.pkl"
    return os.path.join(base_dir, filename)



2025-09-01 10:49:46,292 - INFO - ipython-input-4102500602 - Logging initialized. Logs will be saved to ./logs/quickdraw_benchmark_20250901_104946_cuda.log
INFO:__main__:Logging initialized. Logs will be saved to ./logs/quickdraw_benchmark_20250901_104946_cuda.log


📚 Defined 51 objects to recognize.


# Generate text embeddings for the QUICKDRAW_CATEGORIES_DETAILED

# Load the multimodal embedding model
model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding@001")
text_embeddings = []

# Generate an embedding for each object's name
for item in QUICKDRAW_CATEGORIES_DETAILED:
    # The context_text is optional but can improve quality for ambiguous terms
    embeddings = model.get_embeddings(
        contextual_text=item["name"],
        dimension=1408 # A required, fixed dimension size for this model
    )
    text_embeddings.append(embeddings.text_embedding)
    print(f"Generated embedding for: '{item['name']}'")

print(f"\n✅ Successfully generated {len(text_embeddings)} text embeddings.")
# An embedding is just a list of numbers
print(f"   Each embedding has {len(text_embeddings[0])} dimensions.")

In [7]:

# --- Part 1: Data Download and Preparation ---
# %%
def download_quickdraw_binary(category_name, download_dir):
    """
    Downloads the .bin file for a given QuickDraw category.
    Files are named 'full_binary_{category_name_underscored}.bin'.
    """
    # Sanitize category name for filename (replace spaces with underscores)
    filename_category_part = category_name.replace(' ', '_')
    local_filename = f"full_binary_{filename_category_part}.bin"
    local_filepath = os.path.join(download_dir, local_filename)

    if os.path.exists(local_filepath):
        logger.info(f"File for '{category_name}' already exists: {local_filepath}")
        return

    # URL encode category name for the download URL (e.g., "ice cream" -> "ice%20cream")
    url_category_part = urllib.parse.quote(category_name)
    url = f"https://storage.googleapis.com/quickdraw_dataset/full/binary/{url_category_part}.bin"

    logger.info(f"Downloading '{category_name}' from {url} to {local_filepath}...")
    try:
        response = requests.get(url, stream=True)
        response.raise_for_status()  # Raise an exception for HTTP errors

        total_size = int(response.headers.get('content-length', 0))

        with open(local_filepath, 'wb') as f, tqdm(
            desc=category_name,
            total=total_size,
            unit='iB',
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for chunk in response.iter_content(chunk_size=8192):
                size = f.write(chunk)
                bar.update(size)
        logger.info(f"Successfully downloaded '{category_name}'.")
    except requests.exceptions.RequestException as e:
        logger.info(f"Error downloading '{category_name}': {e}")
        if os.path.exists(local_filepath): # Clean up partial download
            os.remove(local_filepath)
    except Exception as e:
        logger.info(f"An unexpected error occurred while downloading '{category_name}': {e}")
        if os.path.exists(local_filepath): # Clean up partial download
            os.remove(local_filepath)


logger.info(f"Starting download process for {len(QUICKDRAW_CATEGORIES)} categories into '{BINARY_DATA_ROOT}'...")
for category in QUICKDRAW_CATEGORIES:
    download_quickdraw_binary(category, BINARY_DATA_ROOT)
logger.info("Download process finished.")

# Explicitly flush handlers after a significant phase
for handler in logger.handlers:
    if isinstance(handler, logging.FileHandler):
        handler.flush()




2025-09-01 10:38:34,777 - INFO - ipython-input-2291996958 - Starting download process for 51 categories into './data'...
INFO:__main__:Starting download process for 51 categories into './data'...
2025-09-01 10:38:34,778 - INFO - ipython-input-2291996958 - Downloading 'apple' from https://storage.googleapis.com/quickdraw_dataset/full/binary/apple.bin to ./data/full_binary_apple.bin...
INFO:__main__:Downloading 'apple' from https://storage.googleapis.com/quickdraw_dataset/full/binary/apple.bin to ./data/full_binary_apple.bin...


apple:   0%|          | 0.00/13.2M [00:00<?, ?iB/s]

2025-09-01 10:38:35,330 - INFO - ipython-input-2291996958 - Successfully downloaded 'apple'.
INFO:__main__:Successfully downloaded 'apple'.
2025-09-01 10:38:35,333 - INFO - ipython-input-2291996958 - Downloading 'cat' from https://storage.googleapis.com/quickdraw_dataset/full/binary/cat.bin to ./data/full_binary_cat.bin...
INFO:__main__:Downloading 'cat' from https://storage.googleapis.com/quickdraw_dataset/full/binary/cat.bin to ./data/full_binary_cat.bin...


cat:   0%|          | 0.00/18.7M [00:00<?, ?iB/s]

2025-09-01 10:38:35,953 - INFO - ipython-input-2291996958 - Successfully downloaded 'cat'.
INFO:__main__:Successfully downloaded 'cat'.
2025-09-01 10:38:35,955 - INFO - ipython-input-2291996958 - Downloading 'dog' from https://storage.googleapis.com/quickdraw_dataset/full/binary/dog.bin to ./data/full_binary_dog.bin...
INFO:__main__:Downloading 'dog' from https://storage.googleapis.com/quickdraw_dataset/full/binary/dog.bin to ./data/full_binary_dog.bin...


dog:   0%|          | 0.00/22.4M [00:00<?, ?iB/s]

2025-09-01 10:38:36,659 - INFO - ipython-input-2291996958 - Successfully downloaded 'dog'.
INFO:__main__:Successfully downloaded 'dog'.
2025-09-01 10:38:36,661 - INFO - ipython-input-2291996958 - Downloading 'door' from https://storage.googleapis.com/quickdraw_dataset/full/binary/door.bin to ./data/full_binary_door.bin...
INFO:__main__:Downloading 'door' from https://storage.googleapis.com/quickdraw_dataset/full/binary/door.bin to ./data/full_binary_door.bin...


door:   0%|          | 0.00/8.15M [00:00<?, ?iB/s]

2025-09-01 10:38:36,999 - INFO - ipython-input-2291996958 - Successfully downloaded 'door'.
INFO:__main__:Successfully downloaded 'door'.
2025-09-01 10:38:37,001 - INFO - ipython-input-2291996958 - Downloading 'elephant' from https://storage.googleapis.com/quickdraw_dataset/full/binary/elephant.bin to ./data/full_binary_elephant.bin...
INFO:__main__:Downloading 'elephant' from https://storage.googleapis.com/quickdraw_dataset/full/binary/elephant.bin to ./data/full_binary_elephant.bin...


elephant:   0%|          | 0.00/17.9M [00:00<?, ?iB/s]

2025-09-01 10:38:37,691 - INFO - ipython-input-2291996958 - Successfully downloaded 'elephant'.
INFO:__main__:Successfully downloaded 'elephant'.
2025-09-01 10:38:37,693 - INFO - ipython-input-2291996958 - Downloading 'fish' from https://storage.googleapis.com/quickdraw_dataset/full/binary/fish.bin to ./data/full_binary_fish.bin...
INFO:__main__:Downloading 'fish' from https://storage.googleapis.com/quickdraw_dataset/full/binary/fish.bin to ./data/full_binary_fish.bin...


fish:   0%|          | 0.00/11.6M [00:00<?, ?iB/s]

2025-09-01 10:38:38,053 - INFO - ipython-input-2291996958 - Successfully downloaded 'fish'.
INFO:__main__:Successfully downloaded 'fish'.
2025-09-01 10:38:38,055 - INFO - ipython-input-2291996958 - Downloading 'flower' from https://storage.googleapis.com/quickdraw_dataset/full/binary/flower.bin to ./data/full_binary_flower.bin...
INFO:__main__:Downloading 'flower' from https://storage.googleapis.com/quickdraw_dataset/full/binary/flower.bin to ./data/full_binary_flower.bin...


flower:   0%|          | 0.00/20.5M [00:00<?, ?iB/s]

2025-09-01 10:38:38,770 - INFO - ipython-input-2291996958 - Successfully downloaded 'flower'.
INFO:__main__:Successfully downloaded 'flower'.
2025-09-01 10:38:38,776 - INFO - ipython-input-2291996958 - Downloading 'table' from https://storage.googleapis.com/quickdraw_dataset/full/binary/table.bin to ./data/full_binary_table.bin...
INFO:__main__:Downloading 'table' from https://storage.googleapis.com/quickdraw_dataset/full/binary/table.bin to ./data/full_binary_table.bin...


table:   0%|          | 0.00/9.30M [00:00<?, ?iB/s]

2025-09-01 10:38:39,210 - INFO - ipython-input-2291996958 - Successfully downloaded 'table'.
INFO:__main__:Successfully downloaded 'table'.
2025-09-01 10:38:39,212 - INFO - ipython-input-2291996958 - Downloading 'grass' from https://storage.googleapis.com/quickdraw_dataset/full/binary/grass.bin to ./data/full_binary_grass.bin...
INFO:__main__:Downloading 'grass' from https://storage.googleapis.com/quickdraw_dataset/full/binary/grass.bin to ./data/full_binary_grass.bin...


grass:   0%|          | 0.00/11.3M [00:00<?, ?iB/s]

2025-09-01 10:38:39,636 - INFO - ipython-input-2291996958 - Successfully downloaded 'grass'.
INFO:__main__:Successfully downloaded 'grass'.
2025-09-01 10:38:39,640 - INFO - ipython-input-2291996958 - Downloading 'house' from https://storage.googleapis.com/quickdraw_dataset/full/binary/house.bin to ./data/full_binary_house.bin...
INFO:__main__:Downloading 'house' from https://storage.googleapis.com/quickdraw_dataset/full/binary/house.bin to ./data/full_binary_house.bin...


house:   0%|          | 0.00/10.4M [00:00<?, ?iB/s]

2025-09-01 10:38:40,051 - INFO - ipython-input-2291996958 - Successfully downloaded 'house'.
INFO:__main__:Successfully downloaded 'house'.
2025-09-01 10:38:40,054 - INFO - ipython-input-2291996958 - Downloading 'ice cream' from https://storage.googleapis.com/quickdraw_dataset/full/binary/ice%20cream.bin to ./data/full_binary_ice_cream.bin...
INFO:__main__:Downloading 'ice cream' from https://storage.googleapis.com/quickdraw_dataset/full/binary/ice%20cream.bin to ./data/full_binary_ice_cream.bin...


ice cream:   0%|          | 0.00/11.3M [00:00<?, ?iB/s]

2025-09-01 10:38:40,646 - INFO - ipython-input-2291996958 - Successfully downloaded 'ice cream'.
INFO:__main__:Successfully downloaded 'ice cream'.
2025-09-01 10:38:40,648 - INFO - ipython-input-2291996958 - Downloading 'circle' from https://storage.googleapis.com/quickdraw_dataset/full/binary/circle.bin to ./data/full_binary_circle.bin...
INFO:__main__:Downloading 'circle' from https://storage.googleapis.com/quickdraw_dataset/full/binary/circle.bin to ./data/full_binary_circle.bin...


circle:   0%|          | 0.00/8.56M [00:00<?, ?iB/s]

2025-09-01 10:38:40,912 - INFO - ipython-input-2291996958 - Successfully downloaded 'circle'.
INFO:__main__:Successfully downloaded 'circle'.
2025-09-01 10:38:40,916 - INFO - ipython-input-2291996958 - Downloading 'key' from https://storage.googleapis.com/quickdraw_dataset/full/binary/key.bin to ./data/full_binary_key.bin...
INFO:__main__:Downloading 'key' from https://storage.googleapis.com/quickdraw_dataset/full/binary/key.bin to ./data/full_binary_key.bin...


key:   0%|          | 0.00/16.1M [00:00<?, ?iB/s]

2025-09-01 10:38:41,399 - INFO - ipython-input-2291996958 - Successfully downloaded 'key'.
INFO:__main__:Successfully downloaded 'key'.
2025-09-01 10:38:41,404 - INFO - ipython-input-2291996958 - Downloading 'lion' from https://storage.googleapis.com/quickdraw_dataset/full/binary/lion.bin to ./data/full_binary_lion.bin...
INFO:__main__:Downloading 'lion' from https://storage.googleapis.com/quickdraw_dataset/full/binary/lion.bin to ./data/full_binary_lion.bin...


lion:   0%|          | 0.00/23.5M [00:00<?, ?iB/s]

2025-09-01 10:38:41,998 - INFO - ipython-input-2291996958 - Successfully downloaded 'lion'.
INFO:__main__:Successfully downloaded 'lion'.
2025-09-01 10:38:42,004 - INFO - ipython-input-2291996958 - Downloading 'moon' from https://storage.googleapis.com/quickdraw_dataset/full/binary/moon.bin to ./data/full_binary_moon.bin...
INFO:__main__:Downloading 'moon' from https://storage.googleapis.com/quickdraw_dataset/full/binary/moon.bin to ./data/full_binary_moon.bin...


moon:   0%|          | 0.00/11.1M [00:00<?, ?iB/s]

2025-09-01 10:38:42,649 - INFO - ipython-input-2291996958 - Successfully downloaded 'moon'.
INFO:__main__:Successfully downloaded 'moon'.
2025-09-01 10:38:42,651 - INFO - ipython-input-2291996958 - Downloading 'nose' from https://storage.googleapis.com/quickdraw_dataset/full/binary/nose.bin to ./data/full_binary_nose.bin...
INFO:__main__:Downloading 'nose' from https://storage.googleapis.com/quickdraw_dataset/full/binary/nose.bin to ./data/full_binary_nose.bin...


nose:   0%|          | 0.00/13.2M [00:00<?, ?iB/s]

2025-09-01 10:38:43,276 - INFO - ipython-input-2291996958 - Successfully downloaded 'nose'.
INFO:__main__:Successfully downloaded 'nose'.
2025-09-01 10:38:43,278 - INFO - ipython-input-2291996958 - Downloading 'pencil' from https://storage.googleapis.com/quickdraw_dataset/full/binary/pencil.bin to ./data/full_binary_pencil.bin...
INFO:__main__:Downloading 'pencil' from https://storage.googleapis.com/quickdraw_dataset/full/binary/pencil.bin to ./data/full_binary_pencil.bin...


pencil:   0%|          | 0.00/9.62M [00:00<?, ?iB/s]

2025-09-01 10:38:43,766 - INFO - ipython-input-2291996958 - Successfully downloaded 'pencil'.
INFO:__main__:Successfully downloaded 'pencil'.
2025-09-01 10:38:43,771 - INFO - ipython-input-2291996958 - Downloading 'rabbit' from https://storage.googleapis.com/quickdraw_dataset/full/binary/rabbit.bin to ./data/full_binary_rabbit.bin...
INFO:__main__:Downloading 'rabbit' from https://storage.googleapis.com/quickdraw_dataset/full/binary/rabbit.bin to ./data/full_binary_rabbit.bin...


rabbit:   0%|          | 0.00/24.1M [00:00<?, ?iB/s]

2025-09-01 10:38:44,492 - INFO - ipython-input-2291996958 - Successfully downloaded 'rabbit'.
INFO:__main__:Successfully downloaded 'rabbit'.
2025-09-01 10:38:44,496 - INFO - ipython-input-2291996958 - Downloading 'sun' from https://storage.googleapis.com/quickdraw_dataset/full/binary/sun.bin to ./data/full_binary_sun.bin...
INFO:__main__:Downloading 'sun' from https://storage.googleapis.com/quickdraw_dataset/full/binary/sun.bin to ./data/full_binary_sun.bin...


sun:   0%|          | 0.00/15.1M [00:00<?, ?iB/s]

2025-09-01 10:38:45,004 - INFO - ipython-input-2291996958 - Successfully downloaded 'sun'.
INFO:__main__:Successfully downloaded 'sun'.
2025-09-01 10:38:45,007 - INFO - ipython-input-2291996958 - Downloading 'tree' from https://storage.googleapis.com/quickdraw_dataset/full/binary/tree.bin to ./data/full_binary_tree.bin...
INFO:__main__:Downloading 'tree' from https://storage.googleapis.com/quickdraw_dataset/full/binary/tree.bin to ./data/full_binary_tree.bin...


tree:   0%|          | 0.00/19.6M [00:00<?, ?iB/s]

2025-09-01 10:38:45,901 - INFO - ipython-input-2291996958 - Successfully downloaded 'tree'.
INFO:__main__:Successfully downloaded 'tree'.
2025-09-01 10:38:45,904 - INFO - ipython-input-2291996958 - Downloading 'umbrella' from https://storage.googleapis.com/quickdraw_dataset/full/binary/umbrella.bin to ./data/full_binary_umbrella.bin...
INFO:__main__:Downloading 'umbrella' from https://storage.googleapis.com/quickdraw_dataset/full/binary/umbrella.bin to ./data/full_binary_umbrella.bin...


umbrella:   0%|          | 0.00/11.0M [00:00<?, ?iB/s]

2025-09-01 10:38:46,509 - INFO - ipython-input-2291996958 - Successfully downloaded 'umbrella'.
INFO:__main__:Successfully downloaded 'umbrella'.
2025-09-01 10:38:46,514 - INFO - ipython-input-2291996958 - Downloading 'van' from https://storage.googleapis.com/quickdraw_dataset/full/binary/van.bin to ./data/full_binary_van.bin...
INFO:__main__:Downloading 'van' from https://storage.googleapis.com/quickdraw_dataset/full/binary/van.bin to ./data/full_binary_van.bin...


van:   0%|          | 0.00/20.7M [00:00<?, ?iB/s]

2025-09-01 10:38:46,826 - INFO - ipython-input-2291996958 - Successfully downloaded 'van'.
INFO:__main__:Successfully downloaded 'van'.
2025-09-01 10:38:46,830 - INFO - ipython-input-2291996958 - Downloading 'cake' from https://storage.googleapis.com/quickdraw_dataset/full/binary/cake.bin to ./data/full_binary_cake.bin...
INFO:__main__:Downloading 'cake' from https://storage.googleapis.com/quickdraw_dataset/full/binary/cake.bin to ./data/full_binary_cake.bin...


cake:   0%|          | 0.00/17.0M [00:00<?, ?iB/s]

2025-09-01 10:38:47,220 - INFO - ipython-input-2291996958 - Successfully downloaded 'cake'.
INFO:__main__:Successfully downloaded 'cake'.
2025-09-01 10:38:47,222 - INFO - ipython-input-2291996958 - Downloading 'airplane' from https://storage.googleapis.com/quickdraw_dataset/full/binary/airplane.bin to ./data/full_binary_airplane.bin...
INFO:__main__:Downloading 'airplane' from https://storage.googleapis.com/quickdraw_dataset/full/binary/airplane.bin to ./data/full_binary_airplane.bin...


airplane:   0%|          | 0.00/15.0M [00:00<?, ?iB/s]

2025-09-01 10:38:47,745 - INFO - ipython-input-2291996958 - Successfully downloaded 'airplane'.
INFO:__main__:Successfully downloaded 'airplane'.
2025-09-01 10:38:47,752 - INFO - ipython-input-2291996958 - Downloading 'ant' from https://storage.googleapis.com/quickdraw_dataset/full/binary/ant.bin to ./data/full_binary_ant.bin...
INFO:__main__:Downloading 'ant' from https://storage.googleapis.com/quickdraw_dataset/full/binary/ant.bin to ./data/full_binary_ant.bin...


ant:   0%|          | 0.00/17.7M [00:00<?, ?iB/s]

2025-09-01 10:38:48,231 - INFO - ipython-input-2291996958 - Successfully downloaded 'ant'.
INFO:__main__:Successfully downloaded 'ant'.
2025-09-01 10:38:48,236 - INFO - ipython-input-2291996958 - Downloading 'banana' from https://storage.googleapis.com/quickdraw_dataset/full/binary/banana.bin to ./data/full_binary_banana.bin...
INFO:__main__:Downloading 'banana' from https://storage.googleapis.com/quickdraw_dataset/full/binary/banana.bin to ./data/full_binary_banana.bin...


banana:   0%|          | 0.00/23.9M [00:00<?, ?iB/s]

2025-09-01 10:38:48,949 - INFO - ipython-input-2291996958 - Successfully downloaded 'banana'.
INFO:__main__:Successfully downloaded 'banana'.
2025-09-01 10:38:48,955 - INFO - ipython-input-2291996958 - Downloading 'bed' from https://storage.googleapis.com/quickdraw_dataset/full/binary/bed.bin to ./data/full_binary_bed.bin...
INFO:__main__:Downloading 'bed' from https://storage.googleapis.com/quickdraw_dataset/full/binary/bed.bin to ./data/full_binary_bed.bin...


bed:   0%|          | 0.00/10.5M [00:00<?, ?iB/s]

2025-09-01 10:38:49,376 - INFO - ipython-input-2291996958 - Successfully downloaded 'bed'.
INFO:__main__:Successfully downloaded 'bed'.
2025-09-01 10:38:49,379 - INFO - ipython-input-2291996958 - Downloading 'bee' from https://storage.googleapis.com/quickdraw_dataset/full/binary/bee.bin to ./data/full_binary_bee.bin...
INFO:__main__:Downloading 'bee' from https://storage.googleapis.com/quickdraw_dataset/full/binary/bee.bin to ./data/full_binary_bee.bin...


bee:   0%|          | 0.00/19.6M [00:00<?, ?iB/s]

2025-09-01 10:38:50,030 - INFO - ipython-input-2291996958 - Successfully downloaded 'bee'.
INFO:__main__:Successfully downloaded 'bee'.
2025-09-01 10:38:50,036 - INFO - ipython-input-2291996958 - Downloading 'bicycle' from https://storage.googleapis.com/quickdraw_dataset/full/binary/bicycle.bin to ./data/full_binary_bicycle.bin...
INFO:__main__:Downloading 'bicycle' from https://storage.googleapis.com/quickdraw_dataset/full/binary/bicycle.bin to ./data/full_binary_bicycle.bin...


bicycle:   0%|          | 0.00/17.6M [00:00<?, ?iB/s]

2025-09-01 10:38:50,713 - INFO - ipython-input-2291996958 - Successfully downloaded 'bicycle'.
INFO:__main__:Successfully downloaded 'bicycle'.
2025-09-01 10:38:50,716 - INFO - ipython-input-2291996958 - Downloading 'bird' from https://storage.googleapis.com/quickdraw_dataset/full/binary/bird.bin to ./data/full_binary_bird.bin...
INFO:__main__:Downloading 'bird' from https://storage.googleapis.com/quickdraw_dataset/full/binary/bird.bin to ./data/full_binary_bird.bin...


bird:   0%|          | 0.00/16.4M [00:00<?, ?iB/s]

2025-09-01 10:38:51,449 - INFO - ipython-input-2291996958 - Successfully downloaded 'bird'.
INFO:__main__:Successfully downloaded 'bird'.
2025-09-01 10:38:51,451 - INFO - ipython-input-2291996958 - Downloading 'book' from https://storage.googleapis.com/quickdraw_dataset/full/binary/book.bin to ./data/full_binary_book.bin...
INFO:__main__:Downloading 'book' from https://storage.googleapis.com/quickdraw_dataset/full/binary/book.bin to ./data/full_binary_book.bin...


book:   0%|          | 0.00/13.5M [00:00<?, ?iB/s]

2025-09-01 10:38:52,056 - INFO - ipython-input-2291996958 - Successfully downloaded 'book'.
INFO:__main__:Successfully downloaded 'book'.
2025-09-01 10:38:52,061 - INFO - ipython-input-2291996958 - Downloading 'bread' from https://storage.googleapis.com/quickdraw_dataset/full/binary/bread.bin to ./data/full_binary_bread.bin...
INFO:__main__:Downloading 'bread' from https://storage.googleapis.com/quickdraw_dataset/full/binary/bread.bin to ./data/full_binary_bread.bin...


bread:   0%|          | 0.00/9.28M [00:00<?, ?iB/s]

2025-09-01 10:38:52,444 - INFO - ipython-input-2291996958 - Successfully downloaded 'bread'.
INFO:__main__:Successfully downloaded 'bread'.
2025-09-01 10:38:52,447 - INFO - ipython-input-2291996958 - Downloading 'bus' from https://storage.googleapis.com/quickdraw_dataset/full/binary/bus.bin to ./data/full_binary_bus.bin...
INFO:__main__:Downloading 'bus' from https://storage.googleapis.com/quickdraw_dataset/full/binary/bus.bin to ./data/full_binary_bus.bin...


bus:   0%|          | 0.00/24.1M [00:00<?, ?iB/s]

2025-09-01 10:38:53,141 - INFO - ipython-input-2291996958 - Successfully downloaded 'bus'.
INFO:__main__:Successfully downloaded 'bus'.
2025-09-01 10:38:53,143 - INFO - ipython-input-2291996958 - Downloading 'elbow' from https://storage.googleapis.com/quickdraw_dataset/full/binary/elbow.bin to ./data/full_binary_elbow.bin...
INFO:__main__:Downloading 'elbow' from https://storage.googleapis.com/quickdraw_dataset/full/binary/elbow.bin to ./data/full_binary_elbow.bin...


elbow:   0%|          | 0.00/9.12M [00:00<?, ?iB/s]

2025-09-01 10:38:53,599 - INFO - ipython-input-2291996958 - Successfully downloaded 'elbow'.
INFO:__main__:Successfully downloaded 'elbow'.
2025-09-01 10:38:53,602 - INFO - ipython-input-2291996958 - Downloading 'ear' from https://storage.googleapis.com/quickdraw_dataset/full/binary/ear.bin to ./data/full_binary_ear.bin...
INFO:__main__:Downloading 'ear' from https://storage.googleapis.com/quickdraw_dataset/full/binary/ear.bin to ./data/full_binary_ear.bin...


ear:   0%|          | 0.00/9.70M [00:00<?, ?iB/s]

2025-09-01 10:38:54,228 - INFO - ipython-input-2291996958 - Successfully downloaded 'ear'.
INFO:__main__:Successfully downloaded 'ear'.
2025-09-01 10:38:54,232 - INFO - ipython-input-2291996958 - Downloading 'camera' from https://storage.googleapis.com/quickdraw_dataset/full/binary/camera.bin to ./data/full_binary_camera.bin...
INFO:__main__:Downloading 'camera' from https://storage.googleapis.com/quickdraw_dataset/full/binary/camera.bin to ./data/full_binary_camera.bin...


camera:   0%|          | 0.00/13.0M [00:00<?, ?iB/s]

2025-09-01 10:38:54,657 - INFO - ipython-input-2291996958 - Successfully downloaded 'camera'.
INFO:__main__:Successfully downloaded 'camera'.
2025-09-01 10:38:54,659 - INFO - ipython-input-2291996958 - Downloading 'car' from https://storage.googleapis.com/quickdraw_dataset/full/binary/car.bin to ./data/full_binary_car.bin...
INFO:__main__:Downloading 'car' from https://storage.googleapis.com/quickdraw_dataset/full/binary/car.bin to ./data/full_binary_car.bin...


car:   0%|          | 0.00/23.3M [00:00<?, ?iB/s]

2025-09-01 10:38:55,466 - INFO - ipython-input-2291996958 - Successfully downloaded 'car'.
INFO:__main__:Successfully downloaded 'car'.
2025-09-01 10:38:55,473 - INFO - ipython-input-2291996958 - Downloading 'chair' from https://storage.googleapis.com/quickdraw_dataset/full/binary/chair.bin to ./data/full_binary_chair.bin...
INFO:__main__:Downloading 'chair' from https://storage.googleapis.com/quickdraw_dataset/full/binary/chair.bin to ./data/full_binary_chair.bin...


chair:   0%|          | 0.00/16.8M [00:00<?, ?iB/s]

2025-09-01 10:38:55,883 - INFO - ipython-input-2291996958 - Successfully downloaded 'chair'.
INFO:__main__:Successfully downloaded 'chair'.
2025-09-01 10:38:55,885 - INFO - ipython-input-2291996958 - Downloading 'clock' from https://storage.googleapis.com/quickdraw_dataset/full/binary/clock.bin to ./data/full_binary_clock.bin...
INFO:__main__:Downloading 'clock' from https://storage.googleapis.com/quickdraw_dataset/full/binary/clock.bin to ./data/full_binary_clock.bin...


clock:   0%|          | 0.00/12.3M [00:00<?, ?iB/s]

2025-09-01 10:38:56,418 - INFO - ipython-input-2291996958 - Successfully downloaded 'clock'.
INFO:__main__:Successfully downloaded 'clock'.
2025-09-01 10:38:56,420 - INFO - ipython-input-2291996958 - Downloading 'cloud' from https://storage.googleapis.com/quickdraw_dataset/full/binary/cloud.bin to ./data/full_binary_cloud.bin...
INFO:__main__:Downloading 'cloud' from https://storage.googleapis.com/quickdraw_dataset/full/binary/cloud.bin to ./data/full_binary_cloud.bin...


cloud:   0%|          | 0.00/12.7M [00:00<?, ?iB/s]

2025-09-01 10:38:56,974 - INFO - ipython-input-2291996958 - Successfully downloaded 'cloud'.
INFO:__main__:Successfully downloaded 'cloud'.
2025-09-01 10:38:56,976 - INFO - ipython-input-2291996958 - Downloading 'hand' from https://storage.googleapis.com/quickdraw_dataset/full/binary/hand.bin to ./data/full_binary_hand.bin...
INFO:__main__:Downloading 'hand' from https://storage.googleapis.com/quickdraw_dataset/full/binary/hand.bin to ./data/full_binary_hand.bin...


hand:   0%|          | 0.00/29.7M [00:00<?, ?iB/s]

2025-09-01 10:38:57,802 - INFO - ipython-input-2291996958 - Successfully downloaded 'hand'.
INFO:__main__:Successfully downloaded 'hand'.
2025-09-01 10:38:57,804 - INFO - ipython-input-2291996958 - Downloading 'computer' from https://storage.googleapis.com/quickdraw_dataset/full/binary/computer.bin to ./data/full_binary_computer.bin...
INFO:__main__:Downloading 'computer' from https://storage.googleapis.com/quickdraw_dataset/full/binary/computer.bin to ./data/full_binary_computer.bin...


computer:   0%|          | 0.00/13.6M [00:00<?, ?iB/s]

2025-09-01 10:38:58,448 - INFO - ipython-input-2291996958 - Successfully downloaded 'computer'.
INFO:__main__:Successfully downloaded 'computer'.
2025-09-01 10:38:58,450 - INFO - ipython-input-2291996958 - Downloading 'cookie' from https://storage.googleapis.com/quickdraw_dataset/full/binary/cookie.bin to ./data/full_binary_cookie.bin...
INFO:__main__:Downloading 'cookie' from https://storage.googleapis.com/quickdraw_dataset/full/binary/cookie.bin to ./data/full_binary_cookie.bin...


cookie:   0%|          | 0.00/19.7M [00:00<?, ?iB/s]

2025-09-01 10:38:59,244 - INFO - ipython-input-2291996958 - Successfully downloaded 'cookie'.
INFO:__main__:Successfully downloaded 'cookie'.
2025-09-01 10:38:59,249 - INFO - ipython-input-2291996958 - Downloading 'cow' from https://storage.googleapis.com/quickdraw_dataset/full/binary/cow.bin to ./data/full_binary_cow.bin...
INFO:__main__:Downloading 'cow' from https://storage.googleapis.com/quickdraw_dataset/full/binary/cow.bin to ./data/full_binary_cow.bin...


cow:   0%|          | 0.00/24.4M [00:00<?, ?iB/s]

2025-09-01 10:38:59,818 - INFO - ipython-input-2291996958 - Successfully downloaded 'cow'.
INFO:__main__:Successfully downloaded 'cow'.
2025-09-01 10:38:59,820 - INFO - ipython-input-2291996958 - Downloading 'crayon' from https://storage.googleapis.com/quickdraw_dataset/full/binary/crayon.bin to ./data/full_binary_crayon.bin...
INFO:__main__:Downloading 'crayon' from https://storage.googleapis.com/quickdraw_dataset/full/binary/crayon.bin to ./data/full_binary_crayon.bin...


crayon:   0%|          | 0.00/10.8M [00:00<?, ?iB/s]

2025-09-01 10:39:00,367 - INFO - ipython-input-2291996958 - Successfully downloaded 'crayon'.
INFO:__main__:Successfully downloaded 'crayon'.
2025-09-01 10:39:00,373 - INFO - ipython-input-2291996958 - Downloading 'cup' from https://storage.googleapis.com/quickdraw_dataset/full/binary/cup.bin to ./data/full_binary_cup.bin...
INFO:__main__:Downloading 'cup' from https://storage.googleapis.com/quickdraw_dataset/full/binary/cup.bin to ./data/full_binary_cup.bin...


cup:   0%|          | 0.00/12.1M [00:00<?, ?iB/s]

2025-09-01 10:39:00,804 - INFO - ipython-input-2291996958 - Successfully downloaded 'cup'.
INFO:__main__:Successfully downloaded 'cup'.
2025-09-01 10:39:00,806 - INFO - ipython-input-2291996958 - Downloading 'eraser' from https://storage.googleapis.com/quickdraw_dataset/full/binary/eraser.bin to ./data/full_binary_eraser.bin...
INFO:__main__:Downloading 'eraser' from https://storage.googleapis.com/quickdraw_dataset/full/binary/eraser.bin to ./data/full_binary_eraser.bin...


eraser:   0%|          | 0.00/10.4M [00:00<?, ?iB/s]

2025-09-01 10:39:01,573 - INFO - ipython-input-2291996958 - Successfully downloaded 'eraser'.
INFO:__main__:Successfully downloaded 'eraser'.
2025-09-01 10:39:01,577 - INFO - ipython-input-2291996958 - Downloading 'carrot' from https://storage.googleapis.com/quickdraw_dataset/full/binary/carrot.bin to ./data/full_binary_carrot.bin...
INFO:__main__:Downloading 'carrot' from https://storage.googleapis.com/quickdraw_dataset/full/binary/carrot.bin to ./data/full_binary_carrot.bin...


carrot:   0%|          | 0.00/13.5M [00:00<?, ?iB/s]

2025-09-01 10:39:02,189 - INFO - ipython-input-2291996958 - Successfully downloaded 'carrot'.
INFO:__main__:Successfully downloaded 'carrot'.
2025-09-01 10:39:02,191 - INFO - ipython-input-2291996958 - Downloading 'drums' from https://storage.googleapis.com/quickdraw_dataset/full/binary/drums.bin to ./data/full_binary_drums.bin...
INFO:__main__:Downloading 'drums' from https://storage.googleapis.com/quickdraw_dataset/full/binary/drums.bin to ./data/full_binary_drums.bin...


drums:   0%|          | 0.00/18.8M [00:00<?, ?iB/s]

2025-09-01 10:39:02,840 - INFO - ipython-input-2291996958 - Successfully downloaded 'drums'.
INFO:__main__:Successfully downloaded 'drums'.
2025-09-01 10:39:02,842 - INFO - ipython-input-2291996958 - Downloading 'eye' from https://storage.googleapis.com/quickdraw_dataset/full/binary/eye.bin to ./data/full_binary_eye.bin...
INFO:__main__:Downloading 'eye' from https://storage.googleapis.com/quickdraw_dataset/full/binary/eye.bin to ./data/full_binary_eye.bin...


eye:   0%|          | 0.00/16.1M [00:00<?, ?iB/s]

2025-09-01 10:39:03,425 - INFO - ipython-input-2291996958 - Successfully downloaded 'eye'.
INFO:__main__:Successfully downloaded 'eye'.
2025-09-01 10:39:03,427 - INFO - ipython-input-2291996958 - Downloading 'knife' from https://storage.googleapis.com/quickdraw_dataset/full/binary/knife.bin to ./data/full_binary_knife.bin...
INFO:__main__:Downloading 'knife' from https://storage.googleapis.com/quickdraw_dataset/full/binary/knife.bin to ./data/full_binary_knife.bin...


knife:   0%|          | 0.00/12.0M [00:00<?, ?iB/s]

2025-09-01 10:39:03,955 - INFO - ipython-input-2291996958 - Successfully downloaded 'knife'.
INFO:__main__:Successfully downloaded 'knife'.
2025-09-01 10:39:03,960 - INFO - ipython-input-2291996958 - Download process finished.
INFO:__main__:Download process finished.


# QuickDrawBinaryDataset Loader and Indexer

In [8]:
# --- QuickDraw Binary Data Reading Functions (from user - Unchanged) ---
def unpack_drawing(file_handle):
    try:
        key_id, = unpack('Q', file_handle.read(8))
        country_code, = unpack('2s', file_handle.read(2))
        recognized, = unpack('b', file_handle.read(1))
        timestamp, = unpack('I', file_handle.read(4))
        n_strokes, = unpack('H', file_handle.read(2))
        image_strokes = []
        for _ in range(n_strokes):
            n_points, = unpack('H', file_handle.read(2))
            fmt = str(n_points) + 'B'
            if n_points == 0:
                image_strokes.append((tuple(), tuple()))
                continue

            x_bytes = file_handle.read(n_points)
            y_bytes = file_handle.read(n_points)

            if len(x_bytes) < n_points or len(y_bytes) < n_points:
                logger.error(f"Insufficient data for stroke points. Expected {n_points}, got {len(x_bytes)} for x, {len(y_bytes)} for y. Skipping drawing.")
                raise struct.error("Insufficient data for stroke points, likely corrupted drawing record.")

            x = unpack(fmt, x_bytes)
            y = unpack(fmt, y_bytes)
            image_strokes.append((x, y))

        return {
            'key_id': key_id,
            'country_code': country_code,
            'recognized': recognized,
            'timestamp': timestamp,
            'image': image_strokes
        }
    except struct.error as e:
        logger.debug(f"Struct error during unpack_drawing: {e}. File pointer at {file_handle.tell() if hasattr(file_handle, 'tell') else 'N/A'}")
        raise
    except Exception as e:
        logger.error(f"Unexpected error in unpack_drawing: {e}")
        raise


def unpack_drawings(filename): # Unchanged, used by indexing if needed elsewhere
    file_size = os.path.getsize(filename)
    with open(filename, 'rb') as f, tqdm(total=file_size, unit='B', unit_scale=True, desc=f"Unpacking {os.path.basename(filename)}", leave=False) as pbar:
        while True:
            try:
                start_pos = f.tell()
                if start_pos >= file_size:
                    break
                yield unpack_drawing(f)
                pbar.update(f.tell() - start_pos)
            except struct.error:
                logger.debug(f"Struct.error in unpack_drawings, likely end of file or data. File: {filename}")
                break
            except EOFError:
                logger.debug(f"EOFError in unpack_drawings. File: {filename}")
                break

def precompute_all_indices(root_dir, categories):
    """
    Precompute indices for all categories in the binary dataset.
    This function reads each binary file, extracts drawing offsets, and saves them as index files.
    """
    os.makedirs(root_dir, exist_ok=True)

    for category in tqdm(categories, desc="Precomputing indices"):
        # Convert category name for filename
        category_file = category.replace(' ', '_')
        filepath = os.path.join(root_dir, f"full_binary_{category_file}.bin")
        index_path = os.path.join(root_dir, f"full_binary_{category_file}.idx")

        if os.path.exists(index_path):
            logger.info(f"Index for {category} already exists at {index_path}")
            continue

        if not os.path.exists(filepath):
            logger.warning(f"Binary file for {category} not found at {filepath}, skipping index creation")
            continue

        logger.info(f"Creating index for {category}...")
        drawing_offsets = []

        file_size = os.path.getsize(filepath)
        with open(filepath, 'rb') as f:
            with tqdm(total=file_size, unit='B', unit_scale=True, desc=f"Indexing {os.path.basename(filepath)}") as pbar:
                while True:
                    current_pos = f.tell()
                    if current_pos >= file_size:
                        break
                    try:
                        # Record the current position
                        drawing_offsets.append(current_pos)
                        # Read one drawing to advance the pointer
                        _ = unpack_drawing(f)
                        pbar.update(f.tell() - current_pos)
                    except struct.error:
                        logger.debug(f"Finished indexing or encountered struct.error at offset {current_pos}")
                        break
                    except EOFError:
                        logger.debug(f"EOFError encountered while indexing at offset {current_pos}")
                        break
                    except Exception as e:
                        logger.error(f"Unexpected error during indexing at offset {current_pos}: {e}")
                        break

        # Save to disk
        with open(index_path, 'wb') as f:
            pickle.dump(drawing_offsets, f)

        logger.info(f"Saved index with {len(drawing_offsets)} entries to {index_path}")

    logger.info(f"Precomputed indices for {len(categories)} categories")
    return True

# --- Custom QuickDraw Dataset from Local Binary Files (REFACTORED) ---
class QuickDrawBinaryDataset(Dataset):
    IMAGE_SIZE = (256, 256)
    LINE_WIDTH = 2
    # Static dictionary to store cached indices for each file path
    _cached_drawing_offsets = {}

    def __init__(self, root, category, transform=None, cache_size=QUICKDRAW_CACHE_SIZE):
        self.root = root
        # Sanitize category name for filename (replace spaces with underscores)
        self.category = category.replace(' ', '_')
        self.transform = transform
        self.filepath = os.path.join(self.root, f"full_binary_{self.category}.bin")
        self.index_path = os.path.join(self.root, f"full_binary_{self.category}.idx")

        self.cache_size = cache_size
        self.worker_caches = {}  # Dictionary to store worker-specific caches

        if not os.path.exists(self.filepath):
            raise FileNotFoundError(
                f"Dataset binary file not found: {self.filepath}. Please ensure it exists."
            )

        # Load or create the index for this file path
        self.drawing_offsets = self._get_or_create_index()

        if not self.drawing_offsets:
            logger.warning(f"No drawings were indexed for category {self.category} from {self.filepath}.")
        else:
            logger.info(f"Successfully loaded or indexed {len(self.drawing_offsets)} drawings for {self.category}. Cache capacity: {self.cache_size} items.")

    def _get_or_create_index(self):
        # First check in-memory cache
        if self.filepath in self._cached_drawing_offsets:
            logger.debug(f"Using in-memory cached index for {self.filepath}")
            return self._cached_drawing_offsets[self.filepath]

        # Then check for pre-computed index file
        if os.path.exists(self.index_path):
            try:
                logger.info(f"Loading pre-computed index from {self.index_path}")
                with open(self.index_path, 'rb') as f:
                    drawing_offsets = pickle.load(f)
                # Store in memory cache
                self._cached_drawing_offsets[self.filepath] = drawing_offsets
                return drawing_offsets
            except Exception as e:
                logger.error(f"Error loading pre-computed index from {self.index_path}: {e}")
                # Fall back to creating index

        # If we get here, we need to create the index
        logger.info(f"Indexing drawings from {self.filepath} for category {self.category}...")
        drawing_offsets = []

        file_size = os.path.getsize(self.filepath)
        idx_file_handle = None
        try:
            idx_file_handle = open(self.filepath, 'rb')
            with tqdm(total=file_size, unit='B', unit_scale=True, desc=f"Indexing {os.path.basename(self.filepath)}", leave=False) as pbar:
                while True:
                    current_pos = idx_file_handle.tell()
                    if current_pos >= file_size:
                        break
                    try:
                        # Read one drawing just to advance the pointer and validate structure
                        # We use the external unpack_drawing function here
                        drawing_offsets.append(current_pos)
                        _ = unpack_drawing(idx_file_handle)
                        pbar.update(idx_file_handle.tell() - current_pos)
                    except struct.error:
                        logger.debug(f"Finished indexing or encountered struct.error at offset {current_pos} in {self.filepath}. Total indexed: {len(drawing_offsets)}")
                        break
                    except EOFError:
                        logger.debug(f"EOFError encountered while indexing {self.filepath} at offset {current_pos}. Total indexed: {len(drawing_offsets)}")
                        break
                    except Exception as e:
                        logger.error(f"Unexpected error during indexing of {self.filepath} at offset {current_pos}: {e}. Stopping indexing for this file.")
                        break
        finally:
            if idx_file_handle:
                idx_file_handle.close()

        # Store the index in the memory cache
        self._cached_drawing_offsets[self.filepath] = drawing_offsets

        # Also save to disk for future use
        try:
            with open(self.index_path, 'wb') as f:
                pickle.dump(drawing_offsets, f)
                logger.info(f"Saved index with {len(drawing_offsets)} entries to {self.index_path}")
        except Exception as e:
            logger.error(f"Failed to save index to {self.index_path}: {e}")

        return drawing_offsets

    def _render_drawing_to_image(self, drawing_strokes):
        image = Image.new("L", self.IMAGE_SIZE, "white")
        draw = ImageDraw.Draw(image)
        for stroke_x, stroke_y in drawing_strokes:
            if not stroke_x or not stroke_y:
                continue
            if len(stroke_x) == 1:
                draw.point((int(stroke_x[0]), int(stroke_y[0])), fill="black")
            else:
                points = list(zip(stroke_x, stroke_y))
                draw.line(points, fill="black", width=self.LINE_WIDTH)
        return image

    def __len__(self):
        return len(self.drawing_offsets)

    def __getitem__(self, idx):
        if idx < 0 or idx >= len(self.drawing_offsets):
            raise IndexError(f"Index {idx} out of bounds for {len(self.drawing_offsets)} drawings.")

        # Get worker info for process-specific caching
        worker_info = torch.utils.data.get_worker_info()
        worker_id = worker_info.id if worker_info else 0

        # Create worker-specific cache if it doesn't exist
        if worker_id not in self.worker_caches:
            self.worker_caches[worker_id] = collections.OrderedDict()

        # Use worker-specific cache
        worker_cache = self.worker_caches[worker_id]

        drawing_data = None
        if idx in worker_cache:
            drawing_data = worker_cache[idx]
            worker_cache.move_to_end(idx)  # Mark as recently used
        else:
            offset = self.drawing_offsets[idx]
            try:
                # Open file, seek, read one drawing, then close.
                # This is safer for multiprocessing in DataLoader.
                with open(self.filepath, 'rb') as f:
                    f.seek(offset)
                    # Use the external unpack_drawing function here
                    drawing_data = unpack_drawing(f)
            except Exception as e:
                logger.error(f"Error reading drawing at index {idx}, offset {offset} from {self.filepath}: {e}")
                raise IOError(f"Failed to load drawing {idx} for {self.category}") from e

            if self.cache_size > 0:
                worker_cache[idx] = drawing_data
                if len(worker_cache) > self.cache_size:
                    worker_cache.popitem(last=False)  # Remove oldest item (LRU)

        if drawing_data is None: # Should not happen if logic is correct
             raise RuntimeError(f"Drawing data for index {idx} could not be retrieved.")

        pil_image = self._render_drawing_to_image(drawing_data['image'])

        if self.transform:
            pil_image = self.transform(pil_image)

        return pil_image, self.category

### Generate CSV of the 51 QuickDraw Categories

In [9]:
def generate_category_counts_report():
    '''
    Generates a CSV report of all QuickDraw categories and their sample counts.
    Generating category counts report...")
    '''
    category_stats = []

    for category_name in tqdm(QUICKDRAW_CATEGORIES, desc="Counting category samples"):
        try:
            # Create dataset for this category
            filepath = os.path.join(BINARY_DATA_ROOT, f"full_binary_{category_name.replace(' ', '_')}.bin")
            if not os.path.exists(filepath):
                category_stats.append({
                    'Category': category_name,
                    'Total Samples': 0,
                    'Error': 'File not found'
                })
                continue

            dataset = QuickDrawBinaryDataset(
                root=BINARY_DATA_ROOT,
                category=category_name,
                transform=None,
                cache_size=QUICKDRAW_CACHE_SIZE
            )

            total_count = len(dataset)

            # Calculate the actual split sizes based on our splitting logic
            actual_train = min(NUM_TRAIN_SAMPLES_PER_CATEGORY, int(total_count * 0.7))
            remaining = total_count - actual_train
            actual_val = min(int(NUM_TRAIN_SAMPLES_PER_CATEGORY * 0.2), int(remaining * 0.5))
            actual_test = min(NUM_TEST_SAMPLES_PER_CATEGORY, remaining - actual_val)

            category_stats.append({
                'Category': category_name,
                'Total Samples': total_count,
                'Training Samples': actual_train,
                'Validation Samples': actual_val,
                'Test Samples': actual_test
            })

            logger.info(f"Category: {category_name}, Total: {total_count}, Train: {actual_train}, Val: {actual_val}, Test: {actual_test}")

        except Exception as e:
            logger.error(f"Error counting samples for category {category_name}: {e}")
            category_stats.append({
                'Category': category_name,
                'Total Samples': 0,
                'Error': str(e)
            })

    # Save to CSV
    csv_filename = f"quickdraw_category_counts.csv"

    try:
        with open(csv_filename, 'w', newline='') as csvfile:
            # Define fieldnames including all possible columns
            fieldnames = ['Category', 'Total Samples', 'Training Samples', 'Validation Samples', 'Test Samples', 'Error']
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            for stats in category_stats:
                writer.writerow(stats)
        logger.info(f"Category counts saved to {csv_filename}")
    except IOError as e:
        logger.error(f"Could not save category counts to CSV: {e}")

    return category_stats



# Generate category counts report
category_stats = generate_category_counts_report()


Counting category samples:   0%|          | 0/51 [00:00<?, ?it/s]

2025-09-01 10:39:04,300 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_apple.bin for category apple...
INFO:__main__:Indexing drawings from ./data/full_binary_apple.bin for category apple...


Indexing full_binary_apple.bin:   0%|          | 0.00/13.8M [00:00<?, ?B/s]

2025-09-01 10:39:05,116 - INFO - ipython-input-2928541081 - Saved index with 144722 entries to ./data/full_binary_apple.idx
INFO:__main__:Saved index with 144722 entries to ./data/full_binary_apple.idx
2025-09-01 10:39:05,117 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 144722 drawings for apple. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 144722 drawings for apple. Cache capacity: 50000 items.
2025-09-01 10:39:05,119 - INFO - ipython-input-2988690619 - Category: apple, Total: 144722, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: apple, Total: 144722, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:05,121 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_cat.bin for category cat...
INFO:__main__:Indexing drawings from ./data/full_binary_cat.bin for category cat...


Indexing full_binary_cat.bin:   0%|          | 0.00/19.6M [00:00<?, ?B/s]

2025-09-01 10:39:06,542 - INFO - ipython-input-2928541081 - Saved index with 123202 entries to ./data/full_binary_cat.idx
INFO:__main__:Saved index with 123202 entries to ./data/full_binary_cat.idx
2025-09-01 10:39:06,544 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 123202 drawings for cat. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 123202 drawings for cat. Cache capacity: 50000 items.
2025-09-01 10:39:06,546 - INFO - ipython-input-2988690619 - Category: cat, Total: 123202, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: cat, Total: 123202, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:06,549 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_dog.bin for category dog...
INFO:__main__:Indexing drawings from ./data/full_binary_dog.bin for category dog...


Indexing full_binary_dog.bin:   0%|          | 0.00/23.5M [00:00<?, ?B/s]

2025-09-01 10:39:08,040 - INFO - ipython-input-2928541081 - Saved index with 152159 entries to ./data/full_binary_dog.idx
INFO:__main__:Saved index with 152159 entries to ./data/full_binary_dog.idx
2025-09-01 10:39:08,042 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 152159 drawings for dog. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 152159 drawings for dog. Cache capacity: 50000 items.
2025-09-01 10:39:08,043 - INFO - ipython-input-2988690619 - Category: dog, Total: 152159, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: dog, Total: 152159, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:08,046 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_door.bin for category door...
INFO:__main__:Indexing drawings from ./data/full_binary_door.bin for category door...


Indexing full_binary_door.bin:   0%|          | 0.00/8.54M [00:00<?, ?B/s]

2025-09-01 10:39:08,678 - INFO - ipython-input-2928541081 - Saved index with 120230 entries to ./data/full_binary_door.idx
INFO:__main__:Saved index with 120230 entries to ./data/full_binary_door.idx
2025-09-01 10:39:08,679 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 120230 drawings for door. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 120230 drawings for door. Cache capacity: 50000 items.
2025-09-01 10:39:08,681 - INFO - ipython-input-2988690619 - Category: door, Total: 120230, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: door, Total: 120230, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:08,685 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_elephant.bin for category elephant...
INFO:__main__:Indexing drawings from ./data/full_binary_elephant.bin for category elephant...


Indexing full_binary_elephant.bin:   0%|          | 0.00/18.8M [00:00<?, ?B/s]

2025-09-01 10:39:09,739 - INFO - ipython-input-2928541081 - Saved index with 126969 entries to ./data/full_binary_elephant.idx
INFO:__main__:Saved index with 126969 entries to ./data/full_binary_elephant.idx
2025-09-01 10:39:09,740 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 126969 drawings for elephant. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 126969 drawings for elephant. Cache capacity: 50000 items.
2025-09-01 10:39:09,741 - INFO - ipython-input-2988690619 - Category: elephant, Total: 126969, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: elephant, Total: 126969, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:09,745 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_fish.bin for category fish...
INFO:__main__:Indexing drawings from ./data/full_binary_fish.bin for category fish...


Indexing full_binary_fish.bin:   0%|          | 0.00/12.1M [00:00<?, ?B/s]

2025-09-01 10:39:10,547 - INFO - ipython-input-2928541081 - Saved index with 134150 entries to ./data/full_binary_fish.idx
INFO:__main__:Saved index with 134150 entries to ./data/full_binary_fish.idx
2025-09-01 10:39:10,549 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 134150 drawings for fish. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 134150 drawings for fish. Cache capacity: 50000 items.
2025-09-01 10:39:10,551 - INFO - ipython-input-2988690619 - Category: fish, Total: 134150, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: fish, Total: 134150, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:10,554 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_flower.bin for category flower...
INFO:__main__:Indexing drawings from ./data/full_binary_flower.bin for category flower...


Indexing full_binary_flower.bin:   0%|          | 0.00/21.5M [00:00<?, ?B/s]

2025-09-01 10:39:11,735 - INFO - ipython-input-2928541081 - Saved index with 144818 entries to ./data/full_binary_flower.idx
INFO:__main__:Saved index with 144818 entries to ./data/full_binary_flower.idx
2025-09-01 10:39:11,736 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 144818 drawings for flower. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 144818 drawings for flower. Cache capacity: 50000 items.
2025-09-01 10:39:11,738 - INFO - ipython-input-2988690619 - Category: flower, Total: 144818, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: flower, Total: 144818, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:11,741 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_table.bin for category table...
INFO:__main__:Indexing drawings from ./data/full_binary_table.bin for category table...


Indexing full_binary_table.bin:   0%|          | 0.00/9.75M [00:00<?, ?B/s]

2025-09-01 10:39:13,251 - INFO - ipython-input-2928541081 - Saved index with 128021 entries to ./data/full_binary_table.idx
INFO:__main__:Saved index with 128021 entries to ./data/full_binary_table.idx
2025-09-01 10:39:13,257 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 128021 drawings for table. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 128021 drawings for table. Cache capacity: 50000 items.
2025-09-01 10:39:13,259 - INFO - ipython-input-2988690619 - Category: table, Total: 128021, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: table, Total: 128021, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:13,262 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_grass.bin for category grass...
INFO:__main__:Indexing drawings from ./data/full_binary_grass.bin for category grass...


Indexing full_binary_grass.bin:   0%|          | 0.00/11.9M [00:00<?, ?B/s]

2025-09-01 10:39:14,399 - INFO - ipython-input-2928541081 - Saved index with 123071 entries to ./data/full_binary_grass.idx
INFO:__main__:Saved index with 123071 entries to ./data/full_binary_grass.idx
2025-09-01 10:39:14,401 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 123071 drawings for grass. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 123071 drawings for grass. Cache capacity: 50000 items.
2025-09-01 10:39:14,402 - INFO - ipython-input-2988690619 - Category: grass, Total: 123071, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: grass, Total: 123071, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:14,405 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_house.bin for category house...
INFO:__main__:Indexing drawings from ./data/full_binary_house.bin for category house...


Indexing full_binary_house.bin:   0%|          | 0.00/10.9M [00:00<?, ?B/s]

2025-09-01 10:39:15,222 - INFO - ipython-input-2928541081 - Saved index with 135420 entries to ./data/full_binary_house.idx
INFO:__main__:Saved index with 135420 entries to ./data/full_binary_house.idx
2025-09-01 10:39:15,224 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 135420 drawings for house. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 135420 drawings for house. Cache capacity: 50000 items.
2025-09-01 10:39:15,226 - INFO - ipython-input-2988690619 - Category: house, Total: 135420, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: house, Total: 135420, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:15,229 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_ice_cream.bin for category ice_cream...
INFO:__main__:Indexing drawings from ./data/full_binary_ice_cream.bin for category ice_cream...


Indexing full_binary_ice_cream.bin:   0%|          | 0.00/11.9M [00:00<?, ?B/s]

2025-09-01 10:39:16,016 - INFO - ipython-input-2928541081 - Saved index with 123133 entries to ./data/full_binary_ice_cream.idx
INFO:__main__:Saved index with 123133 entries to ./data/full_binary_ice_cream.idx
2025-09-01 10:39:16,017 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 123133 drawings for ice_cream. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 123133 drawings for ice_cream. Cache capacity: 50000 items.
2025-09-01 10:39:16,019 - INFO - ipython-input-2988690619 - Category: ice cream, Total: 123133, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: ice cream, Total: 123133, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:16,022 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_circle.bin for category circle...
INFO:__main__:Indexing drawings from ./data/full_binary_circle.bin for category circle...


Indexing full_binary_circle.bin:   0%|          | 0.00/8.98M [00:00<?, ?B/s]

2025-09-01 10:39:16,553 - INFO - ipython-input-2928541081 - Saved index with 122876 entries to ./data/full_binary_circle.idx
INFO:__main__:Saved index with 122876 entries to ./data/full_binary_circle.idx
2025-09-01 10:39:16,554 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 122876 drawings for circle. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 122876 drawings for circle. Cache capacity: 50000 items.
2025-09-01 10:39:16,557 - INFO - ipython-input-2988690619 - Category: circle, Total: 122876, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: circle, Total: 122876, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:16,560 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_key.bin for category key...
INFO:__main__:Indexing drawings from ./data/full_binary_key.bin for category key...


Indexing full_binary_key.bin:   0%|          | 0.00/16.9M [00:00<?, ?B/s]

2025-09-01 10:39:17,566 - INFO - ipython-input-2928541081 - Saved index with 160965 entries to ./data/full_binary_key.idx
INFO:__main__:Saved index with 160965 entries to ./data/full_binary_key.idx
2025-09-01 10:39:17,567 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 160965 drawings for key. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 160965 drawings for key. Cache capacity: 50000 items.
2025-09-01 10:39:17,569 - INFO - ipython-input-2988690619 - Category: key, Total: 160965, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: key, Total: 160965, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:17,571 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_lion.bin for category lion...
INFO:__main__:Indexing drawings from ./data/full_binary_lion.bin for category lion...


Indexing full_binary_lion.bin:   0%|          | 0.00/24.7M [00:00<?, ?B/s]

2025-09-01 10:39:18,814 - INFO - ipython-input-2928541081 - Saved index with 120949 entries to ./data/full_binary_lion.idx
INFO:__main__:Saved index with 120949 entries to ./data/full_binary_lion.idx
2025-09-01 10:39:18,816 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 120949 drawings for lion. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 120949 drawings for lion. Cache capacity: 50000 items.
2025-09-01 10:39:18,818 - INFO - ipython-input-2988690619 - Category: lion, Total: 120949, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: lion, Total: 120949, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:18,821 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_moon.bin for category moon...
INFO:__main__:Indexing drawings from ./data/full_binary_moon.bin for category moon...


Indexing full_binary_moon.bin:   0%|          | 0.00/11.6M [00:00<?, ?B/s]

2025-09-01 10:39:19,502 - INFO - ipython-input-2928541081 - Saved index with 121661 entries to ./data/full_binary_moon.idx
INFO:__main__:Saved index with 121661 entries to ./data/full_binary_moon.idx
2025-09-01 10:39:19,504 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 121661 drawings for moon. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 121661 drawings for moon. Cache capacity: 50000 items.
2025-09-01 10:39:19,506 - INFO - ipython-input-2988690619 - Category: moon, Total: 121661, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: moon, Total: 121661, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:19,509 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_nose.bin for category nose...
INFO:__main__:Indexing drawings from ./data/full_binary_nose.bin for category nose...


Indexing full_binary_nose.bin:   0%|          | 0.00/13.8M [00:00<?, ?B/s]

2025-09-01 10:39:20,485 - INFO - ipython-input-2928541081 - Saved index with 197573 entries to ./data/full_binary_nose.idx
INFO:__main__:Saved index with 197573 entries to ./data/full_binary_nose.idx
2025-09-01 10:39:20,487 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 197573 drawings for nose. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 197573 drawings for nose. Cache capacity: 50000 items.
2025-09-01 10:39:20,488 - INFO - ipython-input-2988690619 - Category: nose, Total: 197573, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: nose, Total: 197573, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:20,490 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_pencil.bin for category pencil...
INFO:__main__:Indexing drawings from ./data/full_binary_pencil.bin for category pencil...


Indexing full_binary_pencil.bin:   0%|          | 0.00/10.1M [00:00<?, ?B/s]

2025-09-01 10:39:21,340 - INFO - ipython-input-2928541081 - Saved index with 122001 entries to ./data/full_binary_pencil.idx
INFO:__main__:Saved index with 122001 entries to ./data/full_binary_pencil.idx
2025-09-01 10:39:21,341 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 122001 drawings for pencil. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 122001 drawings for pencil. Cache capacity: 50000 items.
2025-09-01 10:39:21,344 - INFO - ipython-input-2988690619 - Category: pencil, Total: 122001, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: pencil, Total: 122001, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:21,347 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_rabbit.bin for category rabbit...
INFO:__main__:Indexing drawings from ./data/full_binary_rabbit.bin for category rabbit...


Indexing full_binary_rabbit.bin:   0%|          | 0.00/25.3M [00:00<?, ?B/s]

2025-09-01 10:39:22,862 - INFO - ipython-input-2928541081 - Saved index with 155288 entries to ./data/full_binary_rabbit.idx
INFO:__main__:Saved index with 155288 entries to ./data/full_binary_rabbit.idx
2025-09-01 10:39:22,864 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 155288 drawings for rabbit. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 155288 drawings for rabbit. Cache capacity: 50000 items.
2025-09-01 10:39:22,866 - INFO - ipython-input-2988690619 - Category: rabbit, Total: 155288, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: rabbit, Total: 155288, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:22,869 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_sun.bin for category sun...
INFO:__main__:Indexing drawings from ./data/full_binary_sun.bin for category sun...


Indexing full_binary_sun.bin:   0%|          | 0.00/15.9M [00:00<?, ?B/s]

2025-09-01 10:39:24,365 - INFO - ipython-input-2928541081 - Saved index with 133781 entries to ./data/full_binary_sun.idx
INFO:__main__:Saved index with 133781 entries to ./data/full_binary_sun.idx
2025-09-01 10:39:24,372 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 133781 drawings for sun. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 133781 drawings for sun. Cache capacity: 50000 items.
2025-09-01 10:39:24,374 - INFO - ipython-input-2988690619 - Category: sun, Total: 133781, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: sun, Total: 133781, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:24,376 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_tree.bin for category tree...
INFO:__main__:Indexing drawings from ./data/full_binary_tree.bin for category tree...


Indexing full_binary_tree.bin:   0%|          | 0.00/20.5M [00:00<?, ?B/s]

2025-09-01 10:39:26,025 - INFO - ipython-input-2928541081 - Saved index with 144721 entries to ./data/full_binary_tree.idx
INFO:__main__:Saved index with 144721 entries to ./data/full_binary_tree.idx
2025-09-01 10:39:26,026 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 144721 drawings for tree. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 144721 drawings for tree. Cache capacity: 50000 items.
2025-09-01 10:39:26,028 - INFO - ipython-input-2988690619 - Category: tree, Total: 144721, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: tree, Total: 144721, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:26,030 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_umbrella.bin for category umbrella...
INFO:__main__:Indexing drawings from ./data/full_binary_umbrella.bin for category umbrella...


Indexing full_binary_umbrella.bin:   0%|          | 0.00/11.5M [00:00<?, ?B/s]

2025-09-01 10:39:26,984 - INFO - ipython-input-2928541081 - Saved index with 124084 entries to ./data/full_binary_umbrella.idx
INFO:__main__:Saved index with 124084 entries to ./data/full_binary_umbrella.idx
2025-09-01 10:39:26,985 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 124084 drawings for umbrella. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 124084 drawings for umbrella. Cache capacity: 50000 items.
2025-09-01 10:39:26,987 - INFO - ipython-input-2988690619 - Category: umbrella, Total: 124084, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: umbrella, Total: 124084, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:26,989 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_van.bin for category van...
INFO:__main__:Indexing drawings from ./data/full_binary_van.bin for category van...


Indexing full_binary_van.bin:   0%|          | 0.00/21.7M [00:00<?, ?B/s]

2025-09-01 10:39:28,280 - INFO - ipython-input-2928541081 - Saved index with 165909 entries to ./data/full_binary_van.idx
INFO:__main__:Saved index with 165909 entries to ./data/full_binary_van.idx
2025-09-01 10:39:28,281 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 165909 drawings for van. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 165909 drawings for van. Cache capacity: 50000 items.
2025-09-01 10:39:28,283 - INFO - ipython-input-2988690619 - Category: van, Total: 165909, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: van, Total: 165909, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:28,285 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_cake.bin for category cake...
INFO:__main__:Indexing drawings from ./data/full_binary_cake.bin for category cake...


Indexing full_binary_cake.bin:   0%|          | 0.00/17.8M [00:00<?, ?B/s]

2025-09-01 10:39:29,612 - INFO - ipython-input-2928541081 - Saved index with 124905 entries to ./data/full_binary_cake.idx
INFO:__main__:Saved index with 124905 entries to ./data/full_binary_cake.idx
2025-09-01 10:39:29,613 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 124905 drawings for cake. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 124905 drawings for cake. Cache capacity: 50000 items.
2025-09-01 10:39:29,615 - INFO - ipython-input-2988690619 - Category: cake, Total: 124905, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: cake, Total: 124905, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:29,617 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_airplane.bin for category airplane...
INFO:__main__:Indexing drawings from ./data/full_binary_airplane.bin for category airplane...


Indexing full_binary_airplane.bin:   0%|          | 0.00/15.7M [00:00<?, ?B/s]

2025-09-01 10:39:30,697 - INFO - ipython-input-2928541081 - Saved index with 151623 entries to ./data/full_binary_airplane.idx
INFO:__main__:Saved index with 151623 entries to ./data/full_binary_airplane.idx
2025-09-01 10:39:30,699 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 151623 drawings for airplane. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 151623 drawings for airplane. Cache capacity: 50000 items.
2025-09-01 10:39:30,701 - INFO - ipython-input-2988690619 - Category: airplane, Total: 151623, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: airplane, Total: 151623, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:30,704 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_ant.bin for category ant...
INFO:__main__:Indexing drawings from ./data/full_binary_ant.bin for category ant...


Indexing full_binary_ant.bin:   0%|          | 0.00/18.6M [00:00<?, ?B/s]

2025-09-01 10:39:32,060 - INFO - ipython-input-2928541081 - Saved index with 124612 entries to ./data/full_binary_ant.idx
INFO:__main__:Saved index with 124612 entries to ./data/full_binary_ant.idx
2025-09-01 10:39:32,061 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 124612 drawings for ant. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 124612 drawings for ant. Cache capacity: 50000 items.
2025-09-01 10:39:32,063 - INFO - ipython-input-2988690619 - Category: ant, Total: 124612, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: ant, Total: 124612, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:32,066 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_banana.bin for category banana...
INFO:__main__:Indexing drawings from ./data/full_binary_banana.bin for category banana...


Indexing full_binary_banana.bin:   0%|          | 0.00/25.1M [00:00<?, ?B/s]

2025-09-01 10:39:33,642 - INFO - ipython-input-2928541081 - Saved index with 307936 entries to ./data/full_binary_banana.idx
INFO:__main__:Saved index with 307936 entries to ./data/full_binary_banana.idx
2025-09-01 10:39:33,643 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 307936 drawings for banana. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 307936 drawings for banana. Cache capacity: 50000 items.
2025-09-01 10:39:33,644 - INFO - ipython-input-2988690619 - Category: banana, Total: 307936, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: banana, Total: 307936, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:33,646 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_bed.bin for category bed...
INFO:__main__:Indexing drawings from ./data/full_binary_bed.bin for category bed...


Indexing full_binary_bed.bin:   0%|          | 0.00/11.1M [00:00<?, ?B/s]

2025-09-01 10:39:34,552 - INFO - ipython-input-2928541081 - Saved index with 113862 entries to ./data/full_binary_bed.idx
INFO:__main__:Saved index with 113862 entries to ./data/full_binary_bed.idx
2025-09-01 10:39:34,553 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 113862 drawings for bed. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 113862 drawings for bed. Cache capacity: 50000 items.
2025-09-01 10:39:34,555 - INFO - ipython-input-2988690619 - Category: bed, Total: 113862, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: bed, Total: 113862, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:34,557 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_bee.bin for category bee...
INFO:__main__:Indexing drawings from ./data/full_binary_bee.bin for category bee...


Indexing full_binary_bee.bin:   0%|          | 0.00/20.5M [00:00<?, ?B/s]

2025-09-01 10:39:35,894 - INFO - ipython-input-2928541081 - Saved index with 120890 entries to ./data/full_binary_bee.idx
INFO:__main__:Saved index with 120890 entries to ./data/full_binary_bee.idx
2025-09-01 10:39:35,896 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 120890 drawings for bee. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 120890 drawings for bee. Cache capacity: 50000 items.
2025-09-01 10:39:35,898 - INFO - ipython-input-2988690619 - Category: bee, Total: 120890, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: bee, Total: 120890, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:35,901 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_bicycle.bin for category bicycle...
INFO:__main__:Indexing drawings from ./data/full_binary_bicycle.bin for category bicycle...


Indexing full_binary_bicycle.bin:   0%|          | 0.00/18.4M [00:00<?, ?B/s]

2025-09-01 10:39:37,748 - INFO - ipython-input-2928541081 - Saved index with 126527 entries to ./data/full_binary_bicycle.idx
INFO:__main__:Saved index with 126527 entries to ./data/full_binary_bicycle.idx
2025-09-01 10:39:37,749 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 126527 drawings for bicycle. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 126527 drawings for bicycle. Cache capacity: 50000 items.
2025-09-01 10:39:37,750 - INFO - ipython-input-2988690619 - Category: bicycle, Total: 126527, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: bicycle, Total: 126527, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:37,752 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_bird.bin for category bird...
INFO:__main__:Indexing drawings from ./data/full_binary_bird.bin for category bird...


Indexing full_binary_bird.bin:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

2025-09-01 10:39:39,490 - INFO - ipython-input-2928541081 - Saved index with 133572 entries to ./data/full_binary_bird.idx
INFO:__main__:Saved index with 133572 entries to ./data/full_binary_bird.idx
2025-09-01 10:39:39,491 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 133572 drawings for bird. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 133572 drawings for bird. Cache capacity: 50000 items.
2025-09-01 10:39:39,494 - INFO - ipython-input-2988690619 - Category: bird, Total: 133572, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: bird, Total: 133572, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:39,496 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_book.bin for category book...
INFO:__main__:Indexing drawings from ./data/full_binary_book.bin for category book...


Indexing full_binary_book.bin:   0%|          | 0.00/14.1M [00:00<?, ?B/s]

2025-09-01 10:39:40,607 - INFO - ipython-input-2928541081 - Saved index with 119364 entries to ./data/full_binary_book.idx
INFO:__main__:Saved index with 119364 entries to ./data/full_binary_book.idx
2025-09-01 10:39:40,609 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 119364 drawings for book. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 119364 drawings for book. Cache capacity: 50000 items.
2025-09-01 10:39:40,612 - INFO - ipython-input-2988690619 - Category: book, Total: 119364, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: book, Total: 119364, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:40,614 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_bread.bin for category bread...
INFO:__main__:Indexing drawings from ./data/full_binary_bread.bin for category bread...


Indexing full_binary_bread.bin:   0%|          | 0.00/9.74M [00:00<?, ?B/s]

2025-09-01 10:39:41,296 - INFO - ipython-input-2928541081 - Saved index with 120570 entries to ./data/full_binary_bread.idx
INFO:__main__:Saved index with 120570 entries to ./data/full_binary_bread.idx
2025-09-01 10:39:41,299 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 120570 drawings for bread. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 120570 drawings for bread. Cache capacity: 50000 items.
2025-09-01 10:39:41,301 - INFO - ipython-input-2988690619 - Category: bread, Total: 120570, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: bread, Total: 120570, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:41,304 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_bus.bin for category bus...
INFO:__main__:Indexing drawings from ./data/full_binary_bus.bin for category bus...


Indexing full_binary_bus.bin:   0%|          | 0.00/25.3M [00:00<?, ?B/s]

2025-09-01 10:39:42,917 - INFO - ipython-input-2928541081 - Saved index with 166208 entries to ./data/full_binary_bus.idx
INFO:__main__:Saved index with 166208 entries to ./data/full_binary_bus.idx
2025-09-01 10:39:42,918 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 166208 drawings for bus. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 166208 drawings for bus. Cache capacity: 50000 items.
2025-09-01 10:39:42,919 - INFO - ipython-input-2988690619 - Category: bus, Total: 166208, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: bus, Total: 166208, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:42,922 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_elbow.bin for category elbow...
INFO:__main__:Indexing drawings from ./data/full_binary_elbow.bin for category elbow...


Indexing full_binary_elbow.bin:   0%|          | 0.00/9.56M [00:00<?, ?B/s]

2025-09-01 10:39:43,650 - INFO - ipython-input-2928541081 - Saved index with 126253 entries to ./data/full_binary_elbow.idx
INFO:__main__:Saved index with 126253 entries to ./data/full_binary_elbow.idx
2025-09-01 10:39:43,651 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 126253 drawings for elbow. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 126253 drawings for elbow. Cache capacity: 50000 items.
2025-09-01 10:39:43,652 - INFO - ipython-input-2988690619 - Category: elbow, Total: 126253, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: elbow, Total: 126253, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:43,655 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_ear.bin for category ear...
INFO:__main__:Indexing drawings from ./data/full_binary_ear.bin for category ear...


Indexing full_binary_ear.bin:   0%|          | 0.00/10.2M [00:00<?, ?B/s]

2025-09-01 10:39:44,306 - INFO - ipython-input-2928541081 - Saved index with 122897 entries to ./data/full_binary_ear.idx
INFO:__main__:Saved index with 122897 entries to ./data/full_binary_ear.idx
2025-09-01 10:39:44,308 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 122897 drawings for ear. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 122897 drawings for ear. Cache capacity: 50000 items.
2025-09-01 10:39:44,309 - INFO - ipython-input-2988690619 - Category: ear, Total: 122897, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: ear, Total: 122897, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:44,315 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_camera.bin for category camera...
INFO:__main__:Indexing drawings from ./data/full_binary_camera.bin for category camera...


Indexing full_binary_camera.bin:   0%|          | 0.00/13.6M [00:00<?, ?B/s]

2025-09-01 10:39:45,160 - INFO - ipython-input-2928541081 - Saved index with 128772 entries to ./data/full_binary_camera.idx
INFO:__main__:Saved index with 128772 entries to ./data/full_binary_camera.idx
2025-09-01 10:39:45,163 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 128772 drawings for camera. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 128772 drawings for camera. Cache capacity: 50000 items.
2025-09-01 10:39:45,164 - INFO - ipython-input-2988690619 - Category: camera, Total: 128772, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: camera, Total: 128772, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:45,167 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_car.bin for category car...
INFO:__main__:Indexing drawings from ./data/full_binary_car.bin for category car...


Indexing full_binary_car.bin:   0%|          | 0.00/24.4M [00:00<?, ?B/s]

2025-09-01 10:39:46,562 - INFO - ipython-input-2928541081 - Saved index with 182764 entries to ./data/full_binary_car.idx
INFO:__main__:Saved index with 182764 entries to ./data/full_binary_car.idx
2025-09-01 10:39:46,563 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 182764 drawings for car. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 182764 drawings for car. Cache capacity: 50000 items.
2025-09-01 10:39:46,564 - INFO - ipython-input-2988690619 - Category: car, Total: 182764, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: car, Total: 182764, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:46,570 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_chair.bin for category chair...
INFO:__main__:Indexing drawings from ./data/full_binary_chair.bin for category chair...


Indexing full_binary_chair.bin:   0%|          | 0.00/17.6M [00:00<?, ?B/s]

2025-09-01 10:39:48,167 - INFO - ipython-input-2928541081 - Saved index with 222706 entries to ./data/full_binary_chair.idx
INFO:__main__:Saved index with 222706 entries to ./data/full_binary_chair.idx
2025-09-01 10:39:48,168 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 222706 drawings for chair. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 222706 drawings for chair. Cache capacity: 50000 items.
2025-09-01 10:39:48,169 - INFO - ipython-input-2988690619 - Category: chair, Total: 222706, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: chair, Total: 222706, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:48,172 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_clock.bin for category clock...
INFO:__main__:Indexing drawings from ./data/full_binary_clock.bin for category clock...


Indexing full_binary_clock.bin:   0%|          | 0.00/12.9M [00:00<?, ?B/s]

2025-09-01 10:39:49,048 - INFO - ipython-input-2928541081 - Saved index with 120536 entries to ./data/full_binary_clock.idx
INFO:__main__:Saved index with 120536 entries to ./data/full_binary_clock.idx
2025-09-01 10:39:49,050 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 120536 drawings for clock. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 120536 drawings for clock. Cache capacity: 50000 items.
2025-09-01 10:39:49,051 - INFO - ipython-input-2988690619 - Category: clock, Total: 120536, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: clock, Total: 120536, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:49,054 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_cloud.bin for category cloud...
INFO:__main__:Indexing drawings from ./data/full_binary_cloud.bin for category cloud...


Indexing full_binary_cloud.bin:   0%|          | 0.00/13.3M [00:00<?, ?B/s]

2025-09-01 10:39:50,029 - INFO - ipython-input-2928541081 - Saved index with 120265 entries to ./data/full_binary_cloud.idx
INFO:__main__:Saved index with 120265 entries to ./data/full_binary_cloud.idx
2025-09-01 10:39:50,030 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 120265 drawings for cloud. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 120265 drawings for cloud. Cache capacity: 50000 items.
2025-09-01 10:39:50,032 - INFO - ipython-input-2988690619 - Category: cloud, Total: 120265, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: cloud, Total: 120265, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:50,034 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_hand.bin for category hand...
INFO:__main__:Indexing drawings from ./data/full_binary_hand.bin for category hand...


Indexing full_binary_hand.bin:   0%|          | 0.00/31.2M [00:00<?, ?B/s]

2025-09-01 10:39:52,126 - INFO - ipython-input-2928541081 - Saved index with 291773 entries to ./data/full_binary_hand.idx
INFO:__main__:Saved index with 291773 entries to ./data/full_binary_hand.idx
2025-09-01 10:39:52,127 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 291773 drawings for hand. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 291773 drawings for hand. Cache capacity: 50000 items.
2025-09-01 10:39:52,128 - INFO - ipython-input-2988690619 - Category: hand, Total: 291773, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: hand, Total: 291773, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:52,131 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_computer.bin for category computer...
INFO:__main__:Indexing drawings from ./data/full_binary_computer.bin for category computer...


Indexing full_binary_computer.bin:   0%|          | 0.00/14.3M [00:00<?, ?B/s]

2025-09-01 10:39:53,172 - INFO - ipython-input-2928541081 - Saved index with 123885 entries to ./data/full_binary_computer.idx
INFO:__main__:Saved index with 123885 entries to ./data/full_binary_computer.idx
2025-09-01 10:39:53,174 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 123885 drawings for computer. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 123885 drawings for computer. Cache capacity: 50000 items.
2025-09-01 10:39:53,175 - INFO - ipython-input-2988690619 - Category: computer, Total: 123885, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: computer, Total: 123885, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:53,178 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_cookie.bin for category cookie...
INFO:__main__:Indexing drawings from ./data/full_binary_cookie.bin for category cookie...


Indexing full_binary_cookie.bin:   0%|          | 0.00/20.6M [00:00<?, ?B/s]

2025-09-01 10:39:54,474 - INFO - ipython-input-2928541081 - Saved index with 131353 entries to ./data/full_binary_cookie.idx
INFO:__main__:Saved index with 131353 entries to ./data/full_binary_cookie.idx
2025-09-01 10:39:54,476 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 131353 drawings for cookie. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 131353 drawings for cookie. Cache capacity: 50000 items.
2025-09-01 10:39:54,477 - INFO - ipython-input-2988690619 - Category: cookie, Total: 131353, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: cookie, Total: 131353, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:54,479 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_cow.bin for category cow...
INFO:__main__:Indexing drawings from ./data/full_binary_cow.bin for category cow...


Indexing full_binary_cow.bin:   0%|          | 0.00/25.6M [00:00<?, ?B/s]

2025-09-01 10:39:56,012 - INFO - ipython-input-2928541081 - Saved index with 123083 entries to ./data/full_binary_cow.idx
INFO:__main__:Saved index with 123083 entries to ./data/full_binary_cow.idx
2025-09-01 10:39:56,014 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 123083 drawings for cow. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 123083 drawings for cow. Cache capacity: 50000 items.
2025-09-01 10:39:56,016 - INFO - ipython-input-2988690619 - Category: cow, Total: 123083, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: cow, Total: 123083, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:56,019 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_crayon.bin for category crayon...
INFO:__main__:Indexing drawings from ./data/full_binary_crayon.bin for category crayon...


Indexing full_binary_crayon.bin:   0%|          | 0.00/11.3M [00:00<?, ?B/s]

2025-09-01 10:39:56,914 - INFO - ipython-input-2928541081 - Saved index with 129953 entries to ./data/full_binary_crayon.idx
INFO:__main__:Saved index with 129953 entries to ./data/full_binary_crayon.idx
2025-09-01 10:39:56,916 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 129953 drawings for crayon. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 129953 drawings for crayon. Cache capacity: 50000 items.
2025-09-01 10:39:56,918 - INFO - ipython-input-2988690619 - Category: crayon, Total: 129953, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: crayon, Total: 129953, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:56,920 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_cup.bin for category cup...
INFO:__main__:Indexing drawings from ./data/full_binary_cup.bin for category cup...


Indexing full_binary_cup.bin:   0%|          | 0.00/12.7M [00:00<?, ?B/s]

2025-09-01 10:39:57,745 - INFO - ipython-input-2928541081 - Saved index with 130721 entries to ./data/full_binary_cup.idx
INFO:__main__:Saved index with 130721 entries to ./data/full_binary_cup.idx
2025-09-01 10:39:57,748 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 130721 drawings for cup. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 130721 drawings for cup. Cache capacity: 50000 items.
2025-09-01 10:39:57,749 - INFO - ipython-input-2988690619 - Category: cup, Total: 130721, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: cup, Total: 130721, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:57,752 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_eraser.bin for category eraser...
INFO:__main__:Indexing drawings from ./data/full_binary_eraser.bin for category eraser...


Indexing full_binary_eraser.bin:   0%|          | 0.00/10.9M [00:00<?, ?B/s]

2025-09-01 10:39:58,594 - INFO - ipython-input-2928541081 - Saved index with 118339 entries to ./data/full_binary_eraser.idx
INFO:__main__:Saved index with 118339 entries to ./data/full_binary_eraser.idx
2025-09-01 10:39:58,596 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 118339 drawings for eraser. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 118339 drawings for eraser. Cache capacity: 50000 items.
2025-09-01 10:39:58,597 - INFO - ipython-input-2988690619 - Category: eraser, Total: 118339, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: eraser, Total: 118339, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:58,600 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_carrot.bin for category carrot...
INFO:__main__:Indexing drawings from ./data/full_binary_carrot.bin for category carrot...


Indexing full_binary_carrot.bin:   0%|          | 0.00/14.1M [00:00<?, ?B/s]

2025-09-01 10:39:59,689 - INFO - ipython-input-2928541081 - Saved index with 132459 entries to ./data/full_binary_carrot.idx
INFO:__main__:Saved index with 132459 entries to ./data/full_binary_carrot.idx
2025-09-01 10:39:59,691 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 132459 drawings for carrot. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 132459 drawings for carrot. Cache capacity: 50000 items.
2025-09-01 10:39:59,693 - INFO - ipython-input-2988690619 - Category: carrot, Total: 132459, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: carrot, Total: 132459, Train: 100, Val: 20, Test: 20
2025-09-01 10:39:59,697 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_drums.bin for category drums...
INFO:__main__:Indexing drawings from ./data/full_binary_drums.bin for category drums...


Indexing full_binary_drums.bin:   0%|          | 0.00/19.7M [00:00<?, ?B/s]

2025-09-01 10:40:01,015 - INFO - ipython-input-2928541081 - Saved index with 137299 entries to ./data/full_binary_drums.idx
INFO:__main__:Saved index with 137299 entries to ./data/full_binary_drums.idx
2025-09-01 10:40:01,017 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 137299 drawings for drums. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 137299 drawings for drums. Cache capacity: 50000 items.
2025-09-01 10:40:01,018 - INFO - ipython-input-2988690619 - Category: drums, Total: 137299, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: drums, Total: 137299, Train: 100, Val: 20, Test: 20
2025-09-01 10:40:01,021 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_eye.bin for category eye...
INFO:__main__:Indexing drawings from ./data/full_binary_eye.bin for category eye...


Indexing full_binary_eye.bin:   0%|          | 0.00/16.9M [00:00<?, ?B/s]

2025-09-01 10:40:02,284 - INFO - ipython-input-2928541081 - Saved index with 125888 entries to ./data/full_binary_eye.idx
INFO:__main__:Saved index with 125888 entries to ./data/full_binary_eye.idx
2025-09-01 10:40:02,285 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 125888 drawings for eye. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 125888 drawings for eye. Cache capacity: 50000 items.
2025-09-01 10:40:02,287 - INFO - ipython-input-2988690619 - Category: eye, Total: 125888, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: eye, Total: 125888, Train: 100, Val: 20, Test: 20
2025-09-01 10:40:02,289 - INFO - ipython-input-2928541081 - Indexing drawings from ./data/full_binary_knife.bin for category knife...
INFO:__main__:Indexing drawings from ./data/full_binary_knife.bin for category knife...


Indexing full_binary_knife.bin:   0%|          | 0.00/12.6M [00:00<?, ?B/s]

2025-09-01 10:40:03,939 - INFO - ipython-input-2928541081 - Saved index with 172656 entries to ./data/full_binary_knife.idx
INFO:__main__:Saved index with 172656 entries to ./data/full_binary_knife.idx
2025-09-01 10:40:03,942 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 172656 drawings for knife. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 172656 drawings for knife. Cache capacity: 50000 items.
2025-09-01 10:40:03,943 - INFO - ipython-input-2988690619 - Category: knife, Total: 172656, Train: 100, Val: 20, Test: 20
INFO:__main__:Category: knife, Total: 172656, Train: 100, Val: 20, Test: 20
2025-09-01 10:40:03,955 - INFO - ipython-input-2988690619 - Category counts saved to quickdraw_category_counts.csv
INFO:__main__:Category counts saved to quickdraw_category_counts.csv


##  Data Preparation with Augmentation

In [10]:
# --- Data Preparation with Augmentation ---
def get_augmented_quickdraw_data(categories, num_train_per_cat, num_test_per_cat, data_root):
    """"
    Enhanced version of get_quickdraw_data that creates train/validation/test splits
    and applies data augmentation to the training set.
    """
    logger.info(f"Loading augmented QuickDraw data for {len(categories)} categories...")

    # Create standard transforms
    base_transform = T.Compose([
        T.Resize((224, 224)), # Resize to 224x224 for ViT compatibility
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Augmented transform for training
    train_transform = T.Compose([
        T.Grayscale(num_output_channels=3),
        T.RandomHorizontalFlip(p=0.5),
        T.RandomRotation(15),
        T.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        T.ColorJitter(brightness=0.2, contrast=0.2),
        T.Resize((224, 224)), # Resize to 224x224 for ViT compatibility
        base_transform
    ])

    # Simple transform for validation/testing
    test_transform = T.Compose([
        T.Grayscale(num_output_channels=3),
        T.Resize((224, 224)), # Resize to 224x224 for ViT compatibility
        base_transform
    ])

    all_train_datasets = []
    all_val_datasets = []
    all_test_datasets = []

    for category_idx, category_name in enumerate(categories):
        try:
            logger.info(f"Loading QuickDraw category: {category_name} from local binary files...")
            full_category_dataset = QuickDrawBinaryDataset(
                root=data_root,
                category=category_name,
                transform=None,  # No transform yet, will be applied per-split
                cache_size=QUICKDRAW_CACHE_SIZE
            )

            if len(full_category_dataset) == 0:
                logger.warning(f"No data loaded for category {category_name}. Skipping.")
                continue

            # Calculate actual samples to use (limited by available data)
            actual_num_train = min(num_train_per_cat, int(len(full_category_dataset) * 0.7))
            remaining_samples = len(full_category_dataset) - actual_num_train

            # Split remaining samples between validation and test
            actual_num_val = min(int(num_train_per_cat * VALIDATION_SPLIT), int(remaining_samples * 0.5))
            actual_num_test = min(NUM_TEST_SAMPLES_PER_CATEGORY, remaining_samples - actual_num_val)

            actual_num_train = int(actual_num_train)
            actual_num_val = int(actual_num_val)
            actual_num_test = int(actual_num_test)


            logger.info(f"Category {category_name}: {actual_num_train} train, {actual_num_val} val, {actual_num_test} test")

            # Ensure we have enough samples
            if actual_num_train == 0 or actual_num_val == 0 or actual_num_test == 0:
                logger.warning(f"Not enough samples in {category_name} for desired splits. Skipping.")
                continue

            # Create indices for random splits
            indices = list(range(len(full_category_dataset)))
            random.shuffle(indices)

            train_indices = indices[:actual_num_train]
            val_indices = indices[actual_num_train:actual_num_train + actual_num_val]
            test_indices = indices[actual_num_train + actual_num_val:actual_num_train + actual_num_val + actual_num_test]

            # Create labeled datasets with appropriate transforms
            train_subset = TransformedSubset(full_category_dataset, train_indices, train_transform, category_idx)
            val_subset = TransformedSubset(full_category_dataset, val_indices, test_transform, category_idx)
            test_subset = TransformedSubset(full_category_dataset, test_indices, test_transform, category_idx)

            all_train_datasets.append(train_subset)
            all_val_datasets.append(val_subset)
            all_test_datasets.append(test_subset)

        except Exception as e:
            logger.error(f"Error loading category {category_name}: {e}")
            continue

    # Ensure we have data for at least some categories
    if not all_train_datasets or not all_val_datasets or not all_test_datasets:
        if not any(all_train_datasets) and not any(all_val_datasets) and not any(all_test_datasets):
            raise RuntimeError("No QuickDraw data could be loaded for any category. Aborting.")
        else:
            logger.warning("Some categories failed to load, proceeding with available data.")


    # Combine datasets across categories
    train_dataset = ConcatDataset(all_train_datasets)
    val_dataset = ConcatDataset(all_val_datasets)
    test_dataset = ConcatDataset(all_test_datasets)


    logger.info(f"Created datasets with {len(train_dataset)} training samples, "
               f"{len(val_dataset)} validation samples, and {len(test_dataset)} test samples.")


    return train_dataset, val_dataset, test_dataset

# Helper class for applying transforms during subset creation
class TransformedSubset(Dataset):
    """
    A subset of a dataset with a transform that's applied on-the-fly.
    Also assigns a fixed class label for classification tasks.
    """
    def __init__(self, dataset, indices, transform, label):
        self.dataset = dataset
        self.indices = indices
        self.transform = transform
        self.label = label

    def __getitem__(self, idx):
        try:
            image, _ = self.dataset[self.indices[idx]]
            if self.transform:
                image = self.transform(image)
            return image, self.label
        except Exception as e:
            # Fallback to a default/placeholder image or retry
            logger.warning(f"Error loading image at index {idx}: {e}")
            # Create a placeholder image (black square)
            placeholder = torch.zeros(3, 224, 224)
            return placeholder, self.label


    def __len__(self):
        return len(self.indices)

# Model Classes

In [11]:
!pip install -U transformers

Collecting transformers
  Downloading transformers-4.56.0-py3-none-any.whl.metadata (40 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.1/40.1 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers<=0.23.0,>=0.22.0 (from transformers)
  Downloading tokenizers-0.22.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)
Downloading transformers-4.56.0-py3-none-any.whl (11.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.6/11.6 MB[0m [31m131.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading tokenizers-0.22.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m118.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tokenizers, transformers
  Attempting uninstall: tokenizers
    Found existing installation: tokenizers 0.21.4
    Uninstalling tokenizers-0.21.4:
      Successfully uninstalled tokenize

In [26]:
# --- Model Wrapper Architecture for Ensemble ---
class ModelWrapper(nn.Module):
    """Base wrapper for models to ensure consistent interface in ensemble"""
    def __init__(self, model=None, num_classes=len(QUICKDRAW_CATEGORIES)):
        super().__init__()
        self.model = model
        self.num_classes = num_classes

    def forward(self, x):
        return self.model(x)

    @classmethod
    def load_from_checkpoint(cls, model_path, device=DEVICE):
        """Load model from checkpoint with proper error handling"""
        try:
            instance = cls()
            state_dict = torch.load(model_path, map_location=device)
            instance.model.load_state_dict(state_dict)
            instance.model.to(device)
            instance.model.eval()
            return instance
        except Exception as e:
            logger.error(f"Failed to load model from {model_path}: {e}")
            return None

class ShuffleNetV2Wrapper(ModelWrapper):
    def __init__(self, weights=None, num_classes=len(QUICKDRAW_CATEGORIES)):
        super().__init__(None, num_classes)
        self.model = models.shufflenet_v2_x0_5(weights=weights)
        in_features = self.model.fc.in_features

        # Replace with optimized classifier for better accuracy
        self.model.fc = nn.Sequential(
            nn.Linear(in_features, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    @classmethod
    def load_from_checkpoint(cls, model_path, device=DEVICE):
        instance = cls(weights=None)
        try:
            state_dict = torch.load(model_path, map_location=device)
            instance.model.load_state_dict(state_dict)
            instance.model.to(device)
            instance.model.eval()
            return instance
        except Exception as e:
            logger.error(f"Failed to load ShuffleNetV2 from {model_path}: {e}")
            return None

class MobileNetV3Wrapper(ModelWrapper):
    def __init__(self, weights=None, num_classes=len(QUICKDRAW_CATEGORIES)):
        super().__init__(None, num_classes)
        self.model = models.mobilenet_v3_small(weights=weights)
        in_features = self.model.classifier[0].in_features

        self.model.classifier = nn.Sequential(
            nn.Linear(in_features, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    @classmethod
    def load_from_checkpoint(cls, model_path, device=DEVICE):
        instance = cls(weights=None)
        try:
            state_dict = torch.load(model_path, map_location=device)
            instance.model.load_state_dict(state_dict)
            instance.model.to(device)
            instance.model.eval()
            return instance
        except Exception as e:
            logger.error(f"Failed to load MobileNetV3 from {model_path}: {e}")
            return None

class SqueezeNetWrapper(ModelWrapper):
    def __init__(self, weights=None, num_classes=len(QUICKDRAW_CATEGORIES)):
        super().__init__(None, num_classes)
        self.model = models.squeezenet1_1(weights=weights)
        in_channels = self.model.classifier[1].in_channels

        # Use the specialized SqueezeNet classifier format
        self.model.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Conv2d(in_channels, 512, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(512, kernel_size=1, out_channels=num_classes)
        )

    @classmethod
    def load_from_checkpoint(cls, model_path, device=DEVICE):
        instance = cls(weights=None)
        try:
            state_dict = torch.load(model_path, map_location=device)
            instance.model.load_state_dict(state_dict)
            instance.model.to(device)
            instance.model.eval()
            return instance
        except Exception as e:
            logger.error(f"Failed to load SqueezeNet from {model_path}: {e}")
            return None

# Add a ModelWrapper for Vision Transformer (ViT)
from transformers import AutoModelForImageClassification

class ViTWrapper(ModelWrapper):
    def __init__(self, model_name="google/vit-base-patch16-224", num_classes=len(QUICKDRAW_CATEGORIES)):
        super().__init__(None, num_classes)
        # Load the pre-trained ViT model
        self.model = AutoModelForImageClassification.from_pretrained(model_name)

        # Modify the classifier head for the new number of classes
        # The exact structure might vary depending on the specific ViT model
        # This is a common pattern for many ViT models
        if hasattr(self.model, 'classifier'):
             num_ftrs = self.model.classifier.in_features
             self.model.classifier = nn.Linear(num_ftrs, num_classes)
        elif hasattr(self.model, 'head'):
             num_ftrs = self.model.head.in_features
             self.model.head = nn.Linear(num_ftrs, num_classes)
        else:
             logger.warning(f"Could not find a standard classifier head in {model_name}. Manual adaptation may be needed.")


    @classmethod
    def load_from_checkpoint(cls, model_path, device=DEVICE):
        instance = cls(model_name="google/vit-base-patch16-224") # Specify the model name here
        try:
            state_dict = torch.load(model_path, map_location=device)
            instance.model.load_state_dict(state_dict)
            instance.model.to(device)
            instance.model.eval()
            return instance
        except Exception as e:
            logger.error(f"Failed to load ViT from {model_path}: {e}")
            return None

# Add a ModelWrapper for MultimodalEmbeddingModel
from google.cloud import aiplatform
# Corrected import: Import Image from vertexai.vision_models
from vertexai.vision_models import Image as VertexAiImage, MultiModalEmbeddingModel
# Import Image from PIL and BytesIO for converting PIL to bytes
from PIL import Image
from io import BytesIO

class MultimodalEmbeddingWrapper(ModelWrapper):
    def __init__(self, project_id, location, num_classes=len(QUICKDRAW_CATEGORIES), categories=QUICKDRAW_CATEGORIES, num_categories=len(QUICKDRAW_CATEGORIES), num_images_per_category=NUM_TRAIN_SAMPLES_PER_CATEGORY):
        super().__init__(None, num_classes)
        self.project_id = project_id
        self.location = location
        self.categories = categories
        # Add the new attributes
        self.num_categories = num_categories
        self.num_images_per_category = num_images_per_category


        # Initialize Vertex AI and load the model
        try:
            aiplatform.init(project=self.project_id, location=self.location)
            self.model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding@001")
            logger.info("Successfully initialized Vertex AI and loaded MultimodalEmbeddingModel.")
        except Exception as e:
            logger.error(f"Failed to initialize Vertex AI or load MultimodalEmbeddingModel: {e}")
            self.model = None # Ensure model is None if loading fails

        # Pre-compute text embeddings for categories
        self.category_text_embeddings = self._get_category_text_embeddings()

        # Add a linear layer to classify image embeddings based on text embeddings
        # The MultimodalEmbeddingModel outputs a fixed dimension (1408)
        if self.category_text_embeddings is not None and self.model is not None:
             # Compute the average text embedding to get the dimension
             avg_embedding = np.mean(self.category_text_embeddings, axis=0)
             input_dim = len(avg_embedding) # Should be 1408
             # Use cosine similarity for classification
             # The text embeddings act as the "weights" for classification
             # We don't need a separate linear layer for this approach
             self.classifier = None # No traditional classifier needed for this method
             logger.info(f"Initialized MultimodalEmbeddingWrapper for similarity-based classification.")
        else:
             self.classifier = None
             logger.warning("Could not initialize MultimodalEmbeddingWrapper for similarity-based classification.")


    def _get_category_text_embeddings(self):
        """Generates text embeddings for each category name."""
        if not self.model:
            logger.warning("MultimodalEmbeddingModel not loaded, cannot get text embeddings.")
            return None

        text_embeddings = []
        logger.info("Generating text embeddings for categories...")
        try:
            for category in tqdm(self.categories, desc="Generating text embeddings"):
                 # The context_text is optional but can improve quality for ambiguous terms
                 embeddings = self.model.get_embeddings(
                     contextual_text=category,
                     dimension=1408 # A required, fixed dimension size for this model
                 )
                 if embeddings and embeddings.text_embedding:
                     text_embeddings.append(embeddings.text_embedding)
                 else:
                     logger.warning(f"Could not get text embedding for category: {category}")

            if not text_embeddings:
                 logger.warning("No text embeddings were generated.")
                 return None

            return np.array(text_embeddings) # Convert to numpy array
        except Exception as e:
             logger.error(f"Error generating text embeddings: {e}")
             return None

    def calculate_cosine_similarity(self, vec1, vec2):
        """Calculates the cosine similarity between two vectors."""
        # Ensure inputs are numpy arrays for dot product and norm
        vec1 = np.asarray(vec1)
        vec2 = np.asarray(vec2)
        # Add a small epsilon to the denominator to avoid division by zero
        return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2) + 1e-8)

    def forward(self, images):
        """
        Generates image embeddings and classifies them using cosine similarity
        with pre-computed text embeddings.
        """
        if self.model is None or self.category_text_embeddings is None:
            logger.error("MultimodalEmbeddingModel or text embeddings not initialized.")
            # Return dummy output
            return torch.zeros(images.size(0), self.num_classes, device=images.device)

        # The MultimodalEmbeddingModel expects VertexAiImage objects initialized with bytes
        # Convert torch tensor batch to list of PIL Images, then to bytes, then to VertexAiImage
        vertex_ai_images = []
        for img_tensor in images:
             # Permute dimensions from C, H, W to H, W, C for numpy conversion
             # Convert to numpy array and scale to 0-255
             img_np = (img_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
             # Convert numpy array to PIL Image
             pil_image = Image.fromarray(img_np)
             # Convert PIL image to bytes
             buffer = BytesIO()
             # Assuming the input images are grayscale sketches, save as PNG
             pil_image.save(buffer, format="PNG")
             image_bytes = buffer.getvalue()
             # Create VertexAiImage object from bytes
             vertex_ai_images.append(VertexAiImage(image_bytes=image_bytes))


        # Get image embeddings from the MultimodalEmbeddingModel
        try:
            # Process images one by one as a workaround for the batch issue
            all_image_embeddings = []
            # Use tqdm for progress tracking
            for i in tqdm(range(len(vertex_ai_images)), desc="Generating image embeddings and classifying", leave=False):
                img_ai = vertex_ai_images[i]
                embeddings_response = self.model.get_embeddings(
                    image=img_ai, # Pass single VertexAiImage object
                    dimension=1408
                )

                if embeddings_response and embeddings_response.image_embedding:
                    all_image_embeddings.append(embeddings_response.image_embedding)
                else:
                    logger.warning(f"Could not get image embedding for image at index {i}. Skipping.")
                    # Append a zero vector or handle appropriately if an embedding is missing
                    # For now, we'll skip and this might cause a shape mismatch later if not all embeddings are generated.
                    # A better approach might be to append a zero vector of the expected dimension.
                    # Let's append a zero vector for now to maintain batch size consistency in the output.
                    all_image_embeddings.append(np.zeros(1408)) # Append zero vector of dimension 1408


            if not all_image_embeddings:
                 logger.warning("No image embeddings were generated.")
                 # Return dummy output if no embeddings were generated
                 return torch.zeros(images.size(0), self.num_classes, device=images.device)


            # Calculate similarity scores for each image embedding with all text embeddings
            batch_similarity_scores = []
            for img_embedding in all_image_embeddings:
                similarity_scores = [self.calculate_cosine_similarity(img_embedding, text_emb) for text_emb in self.category_text_embeddings]
                batch_similarity_scores.append(similarity_scores)

            # Convert similarity scores to logits (higher similarity = higher logit)
            logits = torch.tensor(np.array(batch_similarity_scores), dtype=torch.float32)

            # Move logits to the desired device
            logits = logits.to(images.device)

            return logits

        except Exception as e:
            logger.error(f"Error generating image embeddings or classifying: {e}")
            # Return dummy output
            return torch.zeros(images.size(0), self.num_classes, device=images.device)

    # MultimodalEmbeddingModel does not have state_dict in the traditional PyTorch sense
    # Override state_dict and load_state_dict to handle the classifier (which is now None)
    # and text embeddings if we were to save/load them.
    def state_dict(self):
        # We can save the text embeddings if needed, but for this approach, they are
        # re-generated during initialization based on category names.
        # Saving the state_dict isn't really applicable for this type of model.
        # Returning an empty dict or raising an error might be appropriate.
        # For compatibility with the existing framework, return a minimal dict.
        return {}

    def load_state_dict(self, state_dict):
        # No state to load for this model using this method.
        logger.warning("load_state_dict is not applicable for MultimodalEmbeddingWrapper in similarity mode.")


    @classmethod
    def load_from_checkpoint(cls, model_path, project_id, location, device=DEVICE):
        """Loading from checkpoint is not applicable for this wrapper."""
        logger.warning("load_from_checkpoint is not applicable for MultimodalEmbeddingWrapper in similarity mode.")
        # Return a new instance initialized from scratch
        return cls(project_id=project_id, location=location)

### Utility: Checkpointing during Training

In [27]:
# --- Checkpoint Utility Functions ---
def save_checkpoint(model, optimizer, scheduler, epoch, best_val_accuracy, early_stopping_counter, model_name, is_best=False):
    """
    Save a checkpoint of the model, optimizer, scheduler and training state with git info.
    """
    # Create checkpoint directory if it doesn't exist
    checkpoint_dir = os.path.join(MODEL_SAVE_PATH, "checkpoints", model_name)
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Get git info
    git_info = get_git_info()

    # Prepare checkpoint data
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
        'best_val_accuracy': best_val_accuracy,
        'early_stopping_counter': early_stopping_counter,
        'git_info': git_info
    }


    # Save latest epoch checkpoint (always overwriting previous)
    latest_path = os.path.join(checkpoint_dir, f'checkpoint_latest{LATEST_PATH}.pth')
    torch.save(checkpoint, latest_path)
    logger.info(f"Latest checkpoint saved at {latest_path}")

    # If this is the best model so far, save it separately
    if is_best:
        best_path = os.path.join(MODEL_SAVE_PATH, f"{model_name}_best{LATEST_PATH}.pth")
        torch.save(model.state_dict(), best_path)
        logger.info(f"Best model saved at {best_path}")

        # Also save a complete checkpoint for the best model
        best_checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_best{LATEST_PATH}.pth')
        torch.save(checkpoint, best_checkpoint_path)

def load_checkpoint(model, optimizer, scheduler, model_name):
    """
    Load the latest checkpoint if it exists.
    """
    checkpoint_dir = os.path.join(MODEL_SAVE_PATH, "checkpoints", model_name)


    latest_path = os.path.join(checkpoint_dir, f'checkpoint_latest{LATEST_PATH}.pth')

    # If the checkpoint file doesn't exist, return initial values
    if not os.path.exists(latest_path):
        logger.info(f"No checkpoint found at {latest_path}, starting from scratch.")
        return 0, 0.0, 0

    # Load the checkpoint
    logger.info(f"Loading checkpoint from {latest_path}")
    try:
        checkpoint = torch.load(latest_path, map_location=DEVICE)

        # Load model weights
        model.load_state_dict(checkpoint['model_state_dict'])

        # Load optimizer state
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # Load scheduler state if it exists
        if scheduler is not None and checkpoint['scheduler_state_dict'] is not None:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        # Extract and return training state
        epoch = checkpoint['epoch']
        best_val_accuracy = checkpoint['best_val_accuracy']
        early_stopping_counter = checkpoint['early_stopping_counter']

        logger.info(f"Resuming from epoch {epoch+1} with best validation accuracy: {best_val_accuracy:.2f}%")
        return epoch + 1, best_val_accuracy, early_stopping_counter

    except Exception as e:
        logger.error(f"Error loading checkpoint: {e}")
        return 0, 0.0, 0

# Template code for resuming specific model training
def resume_specific_model_training(model_name):
    # Set up model and optimizer as in the benchmark function
    model_config = MODELS_TO_TEST[model_name]
    model = model_config["model_fn"](weights=model_config["weights"])

    # Set up the classifier based on model type
    # ... (same code as in run_finetuning_benchmark)

    # Initialize optimizer and scheduler
    optimizer = optim.Adam(model.parameters(), lr=FINETUNE_LEARNING_RATE, weight_decay=FINETUNE_WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)

    # Load checkpoint
    start_epoch, best_val_accuracy, early_stopping_counter = load_checkpoint(
        model, optimizer, scheduler, model_name
    )

    if start_epoch == 0:
        logger.info(f"No checkpoint found for {model_name}. Starting from scratch.")
    else:
        logger.info(f"Resuming training for {model_name} from epoch {start_epoch}")

    # Continue with training loop as in the benchmark function
    # ...

In [28]:
def save_benchmark_results(results, benchmark_type='finetuning', base_dir='./results'):
    """
    Saves benchmark results to JSON and CSV files with git commit info and parameters.
    """
    # Create results directory if it doesn't exist
    os.makedirs(base_dir, exist_ok=True)



    base_filename = f"{benchmark_type}{LATEST_PATH}"

    # Prepare paths
    json_path = os.path.join(base_dir, f"{base_filename}.json")
    csv_path = os.path.join(base_dir, f"{base_filename}.csv")

    # Add metadata
    metadata = {
        "git_info": git_info,
        "device": str(DEVICE),
        "num_categories": len(QUICKDRAW_CATEGORIES),
        "samples_per_category": NUM_TRAIN_SAMPLES_PER_CATEGORY,
        "test_samples_per_category": NUM_TEST_SAMPLES_PER_CATEGORY,
        "finetune_epochs": NUM_FINETUNE_EPOCHS
    }

    # Save JSON with metadata
    with open(json_path, 'w') as f:
        json_data = {
            "metadata": metadata,
            "results": results
        }
        json.dump(json_data, f, indent=2)

    # Save CSV
    if results and len(results) > 0:
        fieldnames = set()
        for result in results:
            fieldnames.update(result.keys())

        with open(csv_path, 'w', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=sorted(fieldnames))
            writer.writeheader()
            writer.writerows(results)

    logger.info(f"Benchmark results saved to {json_path} and {csv_path}")
    return json_path, csv_path

# Benchmarks

### Benchmark models

---



In [29]:
import time
import pandas as pd

def benchmark_models_on_test_set(models_to_benchmark, test_loader, device=DEVICE):
    """
    Benchmarks a dictionary of models on the test dataset.

    Args:
        models_to_benchmark (dict): A dictionary where keys are model names (str)
                                    and values are loaded PyTorch models (nn.Module).
        test_loader (DataLoader): DataLoader for the test dataset.
        device (torch.device): The device to run inference on.

    Returns:
        list: A list of dictionaries, each containing benchmark results for a model.
    """
    benchmark_results = []

    logger.info(f"\n--- Starting Benchmarking on Test Set ---")
    logger.info(f"Using device: {device}")

    # Calculate number of images per category and total images in the test set
    num_images_per_category = len(test_loader.dataset) // len(QUICKDRAW_CATEGORIES)
    total_test_images = len(test_loader.dataset)
    logger.info(f"Benchmarking with {total_test_images} total images ({num_images_per_category} images per category).")


    for model_name, model in tqdm(models_to_benchmark.items(), desc="Benchmarking Models", unit="model"):
        logger.info(f"Benchmarking: {model_name}")

        # Ensure model is on the correct device and in evaluation mode
        model.to(device)
        model.eval()

        test_correct = 0
        test_total = 0
        inference_time = 0

        with torch.no_grad():
            start_inference_time = time.time()
            for images, labels in tqdm(test_loader, desc=f"  Evaluating {model_name}", leave=False): # Added tqdm here
                # MultimodalEmbeddingWrapper expects PIL images, others expect tensors
                if isinstance(model, MultimodalEmbeddingWrapper):
                     # The MultimodalEmbeddingWrapper handles its own data loading/processing internally
                     # We just need to provide the raw images (as tensors) and labels for comparison
                     # The wrapper's forward method will handle the conversion to PIL and API calls
                     images_for_model = images # Pass the tensor batch
                else:
                    images_for_model, labels = images.to(DEVICE), labels.to(DEVICE)


                # Forward pass
                outputs = model(images_for_model)

                # Access the logits from the model output
                if hasattr(outputs, 'logits'):
                    logits = outputs.logits.to(DEVICE) # Move logits to device
                else:
                    # Assume the output is already logits or the final layer output
                    logits = outputs.to(DEVICE) # Move outputs to device

                _, predicted = torch.max(logits, 1)
                test_correct += (predicted == labels.to(DEVICE)).sum().item()
                test_total += labels.size(0)

            inference_time = time.time() - start_inference_time

        # Calculate accuracy and inference speed
        test_accuracy = 100 * test_correct / test_total if test_total > 0 else 0
        inference_speed = test_total / inference_time if inference_time > 0 else float('inf')

        logger.info(f"  Test Accuracy: {test_accuracy:.2f}%, Inference Time: {inference_time:.4f}s, Inference Speed: {inference_speed:.2f} img/s")

        benchmark_results.append({
            'Model Name': model_name,
            'Test Accuracy (%)': f"{test_accuracy:.2f}",
            'Inference Time (s)': f"{inference_time:.4f}",
            'Inference Speed (img/s)': f"{inference_speed:.2f}",
            'Images per Category': num_images_per_category,
            'Total Test Images': total_test_images # Added this field
        })

        # Clear GPU memory after each model evaluation
        clear_gpu_memory()

    logger.info(f"\n--- Benchmarking Completed ---")
    return benchmark_results

# Example usage:
# Assuming you have a dictionary of loaded models and a test_loader from data preparation

# Load the test dataset (if not already loaded)
_, _, test_dataset = get_augmented_quickdraw_data(
    QUICKDRAW_CATEGORIES,
    NUM_TRAIN_SAMPLES_PER_CATEGORY,
    NUM_TEST_SAMPLES_PER_CATEGORY,
    BINARY_DATA_ROOT
)

# Create a test DataLoader
# Calculate adaptive batch size based on dataset size (using test samples)
effective_batch_size_test = get_adaptive_batch_size(
    NUM_TEST_SAMPLES_PER_CATEGORY, len(QUICKDRAW_CATEGORIES))
logger.info(f"Using adaptive batch size of {effective_batch_size_test} for benchmarking")

test_loader = DataLoader(test_dataset, batch_size=effective_batch_size_test,
                         shuffle=False, num_workers=4, pin_memory=True, prefetch_factor=3)


# Load models to benchmark
benchmarked_models = {}

# Load MobileNetV3-Small
try:
    mobilenet_wrapper = MobileNetV3Wrapper()
    # MODIFIED: Ensure the filename matches the saving format with LATEST_PATH
    mobilenet_path = "/content/models/MobileNetV3-Small_best__Checkpoint..._samples8000_epochs20_classes51.pth"
    mobilenet_wrapper.model.load_state_dict(torch.load(mobilenet_path))
    benchmarked_models["MobileNetV3-Small"] = mobilenet_wrapper.model
    logger.info(f"Loaded MobileNetV3-Small from {mobilenet_path}")
except Exception as e:
    logger.warning(f"Could not load MobileNetV3-Small for benchmarking: {e}")

# Load ViT directly
try:
    from transformers import AutoTokenizer, AutoModelForImageClassification
    vit_model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")
    # Assuming the ViT model needs a classifier head for the number of QuickDraw classes
    if hasattr(vit_model, 'classifier'):
         num_ftrs = vit_model.classifier.in_features
         vit_model.classifier = nn.Linear(num_ftrs, len(QUICKDRAW_CATEGORIES))
    elif hasattr(vit_model, 'head'):
         num_ftrs = vit_model.head.in_features
         vit_model.head = nn.Linear(num_ftrs, len(QUICKDRAW_CATEGORIES))
    else:
         logger.warning(f"Could not find a standard classifier head in ViT model. Manual adaptation may be needed.")

    # Try loading a fine-tuned checkpoint if it exists, otherwise use the pre-trained model with modified head
    try:
        vit_path = f"{MODEL_SAVE_PATH}/ViT_best{LATEST_PATH}.pth"
        vit_model.load_state_dict(torch.load(vit_path))
        logger.info(f"Loaded fine-tuned ViT from {vit_path}")
    except Exception as e:
        logger.warning(f"Could not load fine-tuned ViT from {vit_path}: {e}. Using pre-trained ViT with modified head.")

    benchmarked_models["ViT"] = vit_model
    logger.info("Loaded ViT model for benchmarking.")
except Exception as e:
    logger.warning(f"Could not load ViT for benchmarking: {e}")


# Load ShuffleNetV2_x0_5
try:
    shufflenet_wrapper = ShuffleNetV2Wrapper()
    shufflenet_path = f"{MODEL_SAVE_PATH}/ShuffleNetV2_x0_5_best{LATEST_PATH}.pth"
    shufflenet_wrapper.model.load_state_dict(torch.load(shufflenet_path))
    benchmarked_models["ShuffleNetV2_x0_5"] = shufflenet_wrapper.model
    logger.info(f"Loaded ShuffleNetV2_x0_5 from {shufflenet_path}")
except Exception as e:
    logger.warning(f"Could not load ShuffleNetV2_x0_5 for benchmarking: {e}")

# Load SqueezeNet1_1
try:
    squeezenet_wrapper = SqueezeNetWrapper()
    squeezenet_path = f"{MODEL_SAVE_PATH}/SqueezeNet1_1_best{LATEST_PATH}.pth"
    squeezenet_wrapper.model.load_state_dict(torch.load(squeezenet_path))
    benchmarked_models["SqueezeNet1_1"] = squeezenet_wrapper.model
    logger.info(f"Loaded SqueezeNet1_1 from {squeezenet_path}")
except Exception as e:
    logger.warning(f"Could not load SqueezeNet1_1 for benchmarking: {e}")

# Load MultimodalEmbeddingModel
try:
    # Make sure you have authenticated using `google.colab.auth.authenticate_user()`
    if not PROJECT_ID and not LOCATION:
         logger.warning("PROJECT_ID or LOCATION not defined. Cannot load MultimodalEmbeddingModel.")
    else:
         multimodal_wrapper = MultimodalEmbeddingWrapper(project_id=PROJECT_ID, location=LOCATION, categories=QUICKDRAW_CATEGORIES)
         # The MultimodalEmbeddingWrapper's classifier needs to be trained or loaded
         # For benchmarking, we'll assume a classifier has been trained and saved.
         # Attempt to load a saved classifier state dict
         try:
             multimodal_path = f"{MODEL_SAVE_PATH}/MultimodalEmbeddingWrapper_classifier_best{LATEST_PATH}.pth"
             multimodal_wrapper.load_state_dict(torch.load(multimodal_path))
             logger.info(f"Loaded MultimodalEmbeddingWrapper classifier from {multimodal_path}")
         except Exception as e:
             logger.warning(f"Could not load MultimodalEmbeddingWrapper classifier from {multimodal_path}: {e}. Using untrained classifier.")

         benchmarked_models["MultimodalEmbedding"] = multimodal_wrapper
         logger.info("Loaded MultimodalEmbedding model for benchmarking.")

except Exception as e:
    logger.warning(f"Could not load MultimodalEmbeddingModel for benchmarking: {e}")


# Check if any models were loaded
if not benchmarked_models:
    logger.error("No models were loaded for benchmarking. Please run the fine-tuning benchmark first.")
else:
    # Run the benchmarking
    benchmark_results = benchmark_models_on_test_set(benchmarked_models, test_loader)

    # Display results in a table using pandas
    results_df = pd.DataFrame(benchmark_results)
    display(results_df)

    # Save the benchmark results
    save_benchmark_results(benchmark_results, benchmark_type='inference_benchmark')

2025-09-01 11:46:32,207 - INFO - ipython-input-2401830890 - Loading augmented QuickDraw data for 51 categories...
INFO:__main__:Loading augmented QuickDraw data for 51 categories...
2025-09-01 11:46:32,210 - INFO - ipython-input-2401830890 - Loading QuickDraw category: apple from local binary files...
INFO:__main__:Loading QuickDraw category: apple from local binary files...
2025-09-01 11:46:32,213 - INFO - ipython-input-2928541081 - Successfully loaded or indexed 144722 drawings for apple. Cache capacity: 50000 items.
INFO:__main__:Successfully loaded or indexed 144722 drawings for apple. Cache capacity: 50000 items.
2025-09-01 11:46:32,214 - INFO - ipython-input-2401830890 - Category apple: 100 train, 10 val, 20 test
INFO:__main__:Category apple: 100 train, 10 val, 20 test
2025-09-01 11:46:32,280 - INFO - ipython-input-2401830890 - Loading QuickDraw category: cat from local binary files...
INFO:__main__:Loading QuickDraw category: cat from local binary files...
2025-09-01 11:46:32,28

Generating text embeddings:   0%|          | 0/51 [00:00<?, ?it/s]

2025-09-01 11:47:23,827 - INFO - ipython-input-1337937114 - Initialized MultimodalEmbeddingWrapper for similarity-based classification.
INFO:__main__:Initialized MultimodalEmbeddingWrapper for similarity-based classification.
2025-09-01 11:47:23,831 - INFO - ipython-input-1511357590 - Loaded MultimodalEmbedding model for benchmarking.
INFO:__main__:Loaded MultimodalEmbedding model for benchmarking.
2025-09-01 11:47:23,832 - INFO - ipython-input-1511357590 - 
--- Starting Benchmarking on Test Set ---
INFO:__main__:
--- Starting Benchmarking on Test Set ---
2025-09-01 11:47:23,834 - INFO - ipython-input-1511357590 - Using device: cuda
INFO:__main__:Using device: cuda
2025-09-01 11:47:23,836 - INFO - ipython-input-1511357590 - Benchmarking with 1020 total images (20 images per category).
INFO:__main__:Benchmarking with 1020 total images (20 images per category).


Benchmarking Models:   0%|          | 0/3 [00:00<?, ?model/s]

2025-09-01 11:47:23,853 - INFO - ipython-input-1511357590 - Benchmarking: MobileNetV3-Small
INFO:__main__:Benchmarking: MobileNetV3-Small


  Evaluating MobileNetV3-Small:   0%|          | 0/8 [00:00<?, ?it/s]

2025-09-01 11:47:26,825 - INFO - ipython-input-1511357590 -   Test Accuracy: 89.80%, Inference Time: 2.9552s, Inference Speed: 345.16 img/s
INFO:__main__:  Test Accuracy: 89.80%, Inference Time: 2.9552s, Inference Speed: 345.16 img/s
2025-09-01 11:47:26,837 - INFO - ipython-input-1511357590 - Benchmarking: ViT
INFO:__main__:Benchmarking: ViT


  Evaluating ViT:   0%|          | 0/8 [00:00<?, ?it/s]

2025-09-01 11:47:39,634 - INFO - ipython-input-1511357590 -   Test Accuracy: 2.55%, Inference Time: 12.6621s, Inference Speed: 80.56 img/s
INFO:__main__:  Test Accuracy: 2.55%, Inference Time: 12.6621s, Inference Speed: 80.56 img/s
2025-09-01 11:47:39,651 - INFO - ipython-input-1511357590 - Benchmarking: MultimodalEmbedding
INFO:__main__:Benchmarking: MultimodalEmbedding


  Evaluating MultimodalEmbedding:   0%|          | 0/8 [00:00<?, ?it/s]

2025-09-01 11:47:41,223 - INFO - ipython-input-1337937114 - Loading image embeddings from ./embeddings/quickdraw_image_embeddings_categories51_samples100.pkl
INFO:__main__:Loading image embeddings from ./embeddings/quickdraw_image_embeddings_categories51_samples100.pkl
2025-09-01 11:47:41,265 - INFO - ipython-input-1337937114 - Successfully loaded 124 image embeddings.
INFO:__main__:Successfully loaded 124 image embeddings.
2025-09-01 11:47:41,351 - ERROR - ipython-input-1337937114 - Mismatch between loaded embeddings (124) and input batch size (128). Returning dummy output.
ERROR:__main__:Mismatch between loaded embeddings (124) and input batch size (128). Returning dummy output.
2025-09-01 11:47:41,371 - INFO - ipython-input-1337937114 - Loading image embeddings from ./embeddings/quickdraw_image_embeddings_categories51_samples100.pkl
INFO:__main__:Loading image embeddings from ./embeddings/quickdraw_image_embeddings_categories51_samples100.pkl
2025-09-01 11:47:41,418 - INFO - ipython

Unnamed: 0,Model Name,Test Accuracy (%),Inference Time (s),Inference Speed (img/s),Images per Category,Total Test Images
0,MobileNetV3-Small,89.8,2.9552,345.16,20,1020
1,ViT,2.55,12.6621,80.56,20,1020
2,MultimodalEmbedding,9.12,2.8459,358.41,20,1020


2025-09-01 11:47:42,536 - INFO - ipython-input-1424645408 - Benchmark results saved to ./results/inference_benchmark__samples100_epochs20_classes51.json and ./results/inference_benchmark__samples100_epochs20_classes51.csv
INFO:__main__:Benchmark results saved to ./results/inference_benchmark__samples100_epochs20_classes51.json and ./results/inference_benchmark__samples100_epochs20_classes51.csv


In [None]:
from google.colab import sheets
sheet = sheets.InteractiveSheet(df=results_df)

https://docs.google.com/spreadsheets/d/1eYqehwmC5TNrU1Fr0CzkYzL81Xo-XcpW2OhWju2oJGs/edit#gid=0


### Finetune and Benchmark

In [None]:
# --- Enhanced Fine-tuning Benchmarking Loop (Part 1) ---
import copy
import random
import torch.optim as optim
import torch.nn as nn


def run_finetuning_benchmark():
    finetuning_results = []

    # Get data with augmentation
    train_dataset, val_dataset, test_dataset = get_augmented_quickdraw_data(
        QUICKDRAW_CATEGORIES,
        NUM_TRAIN_SAMPLES_PER_CATEGORY,
        NUM_TEST_SAMPLES_PER_CATEGORY,
        BINARY_DATA_ROOT
    )

    # In run_finetuning_benchmark function:
    # Calculate adaptive batch size based on dataset size
    effective_batch_size = get_adaptive_batch_size(
        NUM_TRAIN_SAMPLES_PER_CATEGORY, len(QUICKDRAW_CATEGORIES))
    logger.info(
        f"Using adaptive batch size of {effective_batch_size} for feature extraction")

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=effective_batch_size,
                              shuffle=True, num_workers=4, pin_memory=True, prefetch_factor=3)
    val_loader = DataLoader(val_dataset, batch_size=effective_batch_size,
                            shuffle=False, num_workers=4, pin_memory=True, prefetch_factor=3)
    test_loader = DataLoader(test_dataset, batch_size=effective_batch_size,
                             shuffle=False, num_workers=4, pin_memory=True, prefetch_factor=3)

    logger.info(f"\n--- Starting Enhanced Fine-tuning Benchmark ---")
    logger.info(f"Using device: {DEVICE}")
    logger.info(
        f"Fine-tuning on QuickDraw categories: {', '.join(QUICKDRAW_CATEGORIES)}")
    logger.info(
        f"Samples per category: {NUM_TRAIN_SAMPLES_PER_CATEGORY} train, {NUM_TEST_SAMPLES_PER_CATEGORY} test.")
    logger.info(
        f"Number of fine-tuning epochs: {NUM_FINETUNE_EPOCHS}, Learning Rate: {FINETUNE_LEARNING_RATE}")
    logger.info(
        f"Weight Decay: {FINETUNE_WEIGHT_DECAY}, Gradual Unfreezing: {USE_GRADUAL_UNFREEZING}")
    logger.info(
        f"Loading QuickDraw data from local directory: {os.path.abspath(BINARY_DATA_ROOT)}\n")

    num_classes = len(QUICKDRAW_CATEGORIES)

    for model_name, config in tqdm(MODELS_TO_TEST.items(), desc="Benchmarking Models (Enhanced Fine-tuning)", unit="model"):
        logger.info(f"--- Enhanced Fine-tuning Model: {model_name} ---")

        # Clear GPU memory before loading a new model
        clear_gpu_memory()

        current_ft_accuracy = "Error"
        current_ft_train_time = "N/A"
        current_ft_inference_time = "N/A"
        current_ft_inference_speed = "N/A"
        model_params_ft = "Error"
        current_model_path = "Not saved"
        best_val_accuracy = 0.0

        try:
            weights = config["weights"]
            model_to_finetune = config["model_fn"](weights=weights)
            model_params_ft = sum(
                p.numel() for p in model_to_finetune.parameters() if p.requires_grad) / 1_000_000

            # Create custom classifier heads with improved architectures
            if model_name == "MobileNetV3-Small":
                # Save the feature extraction part
                features = model_to_finetune.features
                avgpool = model_to_finetune.avgpool

                # Replace classifier with a better one
                in_features = model_to_finetune.classifier[0].in_features
                model_to_finetune.classifier = nn.Sequential(
                    nn.Linear(in_features, 1024),
                    nn.BatchNorm1d(1024),
                    nn.ReLU(inplace=True),
                    nn.Dropout(0.4),
                    nn.Linear(1024, 512),
                    nn.BatchNorm1d(512),
                    nn.ReLU(inplace=True),
                    nn.Dropout(0.3),
                    nn.Linear(512, num_classes)
                )

                # Define the feature layers for gradual unfreezing
                feature_layers = [features[0], features[1], features[2], features[3],
                                  features[4], features[5], features[6], features[7],
                                  features[8], features[9], features[10], features[11],
                                  avgpool]

            elif model_name == "SqueezeNet1_1":
                # Get the input channels of the classifier
                in_channels = model_to_finetune.classifier[1].in_channels

                # More sophisticated conv classifier
                model_to_finetune.classifier = nn.Sequential(
                    nn.Dropout(p=0.5),
                    nn.Conv2d(in_channels, 512, kernel_size=1),
                    nn.ReLU(inplace=True),
                    nn.AdaptiveAvgPool2d((1, 1)),
                    nn.Conv2d(512, num_classes, kernel_size=1)
                )

                # Define feature layers for gradual unfreezing
                feature_layers = [
                    model_to_finetune.features[0],  # conv1
                    model_to_finetune.features[1],  # maxpool
                    model_to_finetune.features[2],  # fire1
                    model_to_finetune.features[3],  # fire2
                    model_to_finetune.features[4],  # fire3
                    model_to_finetune.features[5],  # fire4
                    model_to_finetune.features[6],  # maxpool
                    model_to_finetune.features[7],  # fire5
                    model_to_finetune.features[8],  # fire6
                    model_to_finetune.features[9],  # fire7
                    model_to_finetune.features[10],  # fire8
                    model_to_finetune.features[11],  # maxpool
                    model_to_finetune.features[12]  # conv10
                ]
            elif model_name == "ShuffleNetV2_x0_5":
                # Get the number of input features for the classifier
                in_features = model_to_finetune.fc.in_features

                # Replace the classifier with an optimized version for better learning
                model_to_finetune.fc = nn.Sequential(
                    nn.Linear(in_features, 1024),
                    nn.BatchNorm1d(1024),
                    nn.ReLU(inplace=True),
                    nn.Dropout(0.4),
                    nn.Linear(1024, 512),
                    nn.BatchNorm1d(512),
                    nn.ReLU(inplace=True),
                    nn.Dropout(0.3),
                    nn.Linear(512, num_classes)
                )

                # Define feature layers for gradual unfreezing
                # This progressive unfreezing helps with learning transfer
                feature_layers = [
                    model_to_finetune.conv1,
                    model_to_finetune.maxpool,
                    model_to_finetune.stage2,
                    model_to_finetune.stage3,
                    model_to_finetune.stage4,
                    model_to_finetune.conv5
                ]

                # Apply a higher initial learning rate for ShuffleNetV2
                # This helps overcome the vanishing gradient problem in lightweight models
                optimizer = optim.Adam(model_to_finetune.parameters(),
                                       lr=FINETUNE_LEARNING_RATE * 1.5,
                                       weight_decay=FINETUNE_WEIGHT_DECAY * 0.8)
            else:
                logger.warning(
                    f"Classifier modification not defined for {model_name}. Skipping fine-tuning.")
                continue

            model_to_finetune.to(DEVICE)

            # Enable gradient checkpointing if configured
            if USE_GRADIENT_CHECKPOINTING and hasattr(model_to_finetune, 'features'):
                logger.info(f"Enabling gradient checkpointing for {model_name}")
                try:
                    # Check if the features module has gradient_checkpointing_enable
                    if hasattr(model_to_finetune.features, 'gradient_checkpointing_enable'):
                        model_to_finetune.features.gradient_checkpointing_enable()
                    # For models where features is a Sequential module
                    elif isinstance(model_to_finetune.features, nn.Sequential):
                        # Skip gradient checkpointing for Sequential modules
                        logger.info(f"Gradient checkpointing not available for {model_name} with Sequential features")
                    else:
                        logger.info(f"Gradient checkpointing not supported for this model architecture")
                except Exception as e:
                    logger.warning(f"Failed to enable gradient checkpointing for {model_name}: {e}")
                    # Continue without gradient checkpointing
            # --- Enhanced Fine-tuning Procedure ---
            logger.info(f"Starting fine-tuning for {model_name}...")
            start_time = time.time()

            # Initialize optimizer, criterion, and scheduler
            optimizer = optim.Adam(model_to_finetune.parameters(
            ), lr=FINETUNE_LEARNING_RATE, weight_decay=FINETUNE_WEIGHT_DECAY)
            criterion = nn.CrossEntropyLoss()
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode='max', factor=0.5, patience=2, verbose=True)

            # Initialize early stopping variables
            early_stopping_patience = 3
            early_stopping_counter = 0

            # Check for existing checkpoints and resume if available
            start_epoch = 0
            if RESUME_FROM_CHECKPOINT:
                start_epoch, best_val_accuracy, early_stopping_counter = load_checkpoint(
                    model_to_finetune, optimizer, scheduler, model_name
                )

            # If we're starting from a checkpoint, we need to restore the best model weights
            if start_epoch > 0:
                best_model_wts = copy.deepcopy(model_to_finetune.state_dict())
            else:
                best_model_wts = copy.deepcopy(model_to_finetune.state_dict())

            # Fine-tuning loop
            for epoch in range(start_epoch, NUM_FINETUNE_EPOCHS):

                # Explicitly flush handlers after a significant phase
                for handler in logger.handlers:
                    if isinstance(handler, logging.FileHandler):
                        handler.flush()

                logger.info(f"Epoch {epoch + 1}/{NUM_FINETUNE_EPOCHS}")

                # Adjust learning rate if using gradual unfreezing
                if USE_GRADUAL_UNFREEZING:
                    if epoch < 10:
                        for param in feature_layers[:epoch + 1]:
                            for p in param.parameters():
                                p.requires_grad = True
                    else:
                        for param in feature_layers:
                            for p in param.parameters():
                                p.requires_grad = True

                # Training phase
                model_to_finetune.train()
                train_loss = 0.0
                train_correct = 0
                train_total = 0

                # Calculate needed gradient accumulation steps based on effective batch size
                if effective_batch_size < MAX_BATCH_SIZE:
                    gradient_accumulation_steps = max(
                        1, MAX_BATCH_SIZE // effective_batch_size)
                    logger.info(
                        f"Using gradient accumulation with {gradient_accumulation_steps} steps")
                else:
                    gradient_accumulation_steps = 1

                # Zero gradients at the beginning of epoch
                optimizer.zero_grad()

                for batch_idx, (images, labels) in enumerate(tqdm(train_loader, desc="Training", leave=False)):
                    images, labels = images.to(DEVICE), labels.to(DEVICE)

                    # Forward pass
                    outputs = model_to_finetune(images)
                    loss = criterion(outputs, labels) / gradient_accumulation_steps  # Normalize loss

                    # Backward pass
                    loss.backward()

                    # Only update weights after accumulating gradients for specified steps
                    if (batch_idx + 1) % gradient_accumulation_steps == 0 or (batch_idx + 1) == len(train_loader):
                        optimizer.step()
                        optimizer.zero_grad()

                    train_loss += loss.item() * images.size(0) * gradient_accumulation_steps  # Scale loss back for reporting
                    _, predicted = torch.max(outputs, 1)
                    train_correct += (predicted == labels).sum().item()
                    train_total += labels.size(0)

                # Calculate average loss and accuracy for the epoch
                epoch_train_loss = train_loss / train_total
                epoch_train_accuracy = 100 * train_correct / train_total
                logger.info(
                    f"Train Loss: {epoch_train_loss:.4f}, Train Accuracy: {epoch_train_accuracy:.2f}%")

                # Validation phase
                model_to_finetune.eval()
                val_loss = 0.0
                val_correct = 0
                val_total = 0

                with torch.no_grad():
                    for images, labels in tqdm(val_loader, desc="Validation", leave=False):
                        images, labels = images.to(DEVICE), labels.to(DEVICE)

                        # Forward pass
                        outputs = model_to_finetune(images)
                        loss = criterion(outputs, labels)

                        val_loss += loss.item() * images.size(0)
                        _, predicted = torch.max(outputs, 1)
                        val_correct += (predicted == labels).sum().item()
                        val_total += labels.size(0)

                # Calculate average loss and accuracy for the validation set
                epoch_val_loss = val_loss / val_total
                epoch_val_accuracy = 100 * val_correct / val_total
                logger.info(
                    f"Val Loss: {epoch_val_loss:.4f}, Val Accuracy: {epoch_val_accuracy:.2f}%")

                # Clear GPU memory after each epoch
                clear_gpu_memory()

                # Early stopping logic
                is_best = False
                if epoch_val_accuracy > best_val_accuracy:
                    best_val_accuracy = epoch_val_accuracy
                    best_model_wts = copy.deepcopy(
                        model_to_finetune.state_dict())
                    logger.info(
                        f"New best model found for {model_name}! (Val Accuracy: {best_val_accuracy:.2f}%)")
                    early_stopping_counter = 0
                    is_best = True
                else:
                    early_stopping_counter += 1
                    logger.info(
                        f"Validation accuracy didn't improve. Counter: {early_stopping_counter}/{early_stopping_patience}")

                # Save checkpoint at specified intervals or if it's the best model
                if (epoch + 1) % CHECKPOINT_INTERVAL == 0 or is_best:
                    save_checkpoint(
                        model_to_finetune,
                        optimizer,
                        scheduler,
                        epoch,
                        best_val_accuracy,
                        early_stopping_counter,
                        model_name,
                        is_best=is_best
                    )

                # Check if early stopping criteria is met
                if early_stopping_counter >= early_stopping_patience:
                    logger.info(
                        f"Early stopping triggered after {epoch+1} epochs")
                    break

                # Learning rate scheduler
                scheduler.step(epoch_val_accuracy)
                logger.info(
                    f"Current learning rate: {optimizer.param_groups[0]['lr']:.6f}")

            # Load the best model weights
            model_to_finetune.load_state_dict(best_model_wts)

            # Now save the final model (end of training)
            final_model_path = f"{MODEL_SAVE_PATH}/{model_name}_final{LATEST_PATH}.pth"
            # Store current state before overwriting
            final_model_wts = copy.deepcopy(model_to_finetune.state_dict())
            # Save the final state
            torch.save(final_model_wts, final_model_path)
            logger.info(f"Final model saved: {final_model_path}")

            # Record total training time
            training_time = time.time() - start_time
            current_ft_train_time = f"{training_time:.2f}"
            logger.info(f"Total training time: {current_ft_train_time}s")

            # Evaluate the model on the test set
            test_loss = 0.0
            test_correct = 0
            test_total = 0

            # Track inference time
            start_inference_time = time.time()
            with torch.no_grad():
                for images, labels in tqdm(test_loader, desc="Testing", leave=False):
                    images, labels = images.to(DEVICE), labels.to(DEVICE)

                    # Forward pass
                    outputs = model_to_finetune(images)
                    loss = criterion(outputs, labels)

                    test_loss += loss.item() * images.size(0)
                    _, predicted = torch.max(outputs, 1)
                    test_correct += (predicted == labels).sum().item()
                    test_total += labels.size(0)

            inference_time = time.time() - start_inference_time
            current_ft_inference_time = f"{inference_time:.4f}"
            current_ft_inference_speed = f"{test_total / inference_time:.2f}"

            # Calculate average loss and accuracy for the test set
            test_loss /= test_total
            test_accuracy = 100 * test_correct / test_total
            logger.info(
                f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")

            # Append the results for this model to the overall results
            finetuning_results.append({
                'model_name': model_name,
                'test_loss': test_loss,
                'test_accuracy': test_accuracy,
                'params': model_params_ft,
                'train_time': current_ft_train_time,
                'inference_time': current_ft_inference_time,
                'inference_speed': current_ft_inference_speed,
                'best_val_accuracy': best_val_accuracy
            })
        except Exception as e:
            logger.error(f"Error occurred while fine-tuning {model_name}: {e}")
            finetuning_results.append({
                'model_name': model_name,
                'test_loss': float('nan'),
                'test_accuracy': 0,
                'params': model_params_ft if isinstance(model_params_ft, (int, float)) else float('nan'),
                'train_time': 'N/A',
                'inference_time': 'N/A',
                'inference_speed': 'N/A',
                'best_val_accuracy': 0,
                'error': str(e)
            })

    # --- Benchmarking Results ---
    logger.info(f"\n--- Enhanced Fine-tuning Benchmarking Results ---")
    for result in finetuning_results:
        if 'error' in result:
            logger.info(
                f"Model: {result['model_name']}, Error: {result['error']}")
        else:
            logger.info(f"Model: {result['model_name']}, Test Loss: {result['test_loss']:.4f}, Test Accuracy: {result['test_accuracy']:.2f}%, Params: {result['params']:.2f}M, Train Time: {result['train_time']}, Inference Time: {result['inference_time']}, Inference Speed: {result['inference_speed']}")

    logger.info(f"Best model(s) based on test accuracy:")
    best_models = sorted(finetuning_results, key=lambda x: x.get(
        'test_accuracy', 0), reverse=True)[:3]
    for bm in best_models:
        logger.info(f" - {bm['model_name']}: {bm['test_accuracy']:.2f}%")

    logger.info(f"Best model(s) based on validation accuracy:")
    best_val_models = sorted(finetuning_results, key=lambda x: x.get(
        'best_val_accuracy', 0), reverse=True)[:3]
    for bvm in best_val_models:
        logger.info(f" - {bvm['model_name']}: {bvm['best_val_accuracy']:.2f}%")

    logger.info(f"\n--- Enhanced Fine-tuning Benchmark Completed ---")

    save_benchmark_results(finetuning_results, benchmark_type='finetuning')
    return finetuning_results

In [None]:
# Run Benchmark for Fine-tuning
run_finetuning_benchmark()

2025-06-01 20:36:50,172 - INFO - 3059701004 - Loading augmented QuickDraw data for 3 categories...
2025-06-01 20:36:50,173 - INFO - 3059701004 - Loading QuickDraw category: apple from local binary files...
2025-06-01 20:36:50,174 - INFO - 4131335437 - Successfully loaded or indexed 144722 drawings for apple. Cache capacity: 50000 items.
2025-06-01 20:36:50,174 - INFO - 3059701004 - Category apple: 8000 train, 1600 val, 2000 test
2025-06-01 20:36:50,224 - INFO - 3059701004 - Loading QuickDraw category: cat from local binary files...
2025-06-01 20:36:50,225 - INFO - 4131335437 - Successfully loaded or indexed 123202 drawings for cat. Cache capacity: 50000 items.
2025-06-01 20:36:50,225 - INFO - 3059701004 - Category cat: 8000 train, 1600 val, 2000 test
2025-06-01 20:36:50,268 - INFO - 3059701004 - Loading QuickDraw category: dog from local binary files...
2025-06-01 20:36:50,268 - INFO - 4131335437 - Successfully loaded or indexed 152159 drawings for dog. Cache capacity: 50000 items.
202

Benchmarking Models (Enhanced Fine-tuning):   0%|          | 0/3 [00:00<?, ?model/s]

2025-06-01 20:36:50,336 - INFO - 2362377068 - --- Enhanced Fine-tuning Model: MobileNetV3-Small ---
2025-06-01 20:36:50,666 - INFO - 2362377068 - Enabling gradient checkpointing for MobileNetV3-Small
2025-06-01 20:36:50,667 - INFO - 2362377068 - Gradient checkpointing not available for MobileNetV3-Small with Sequential features
2025-06-01 20:36:50,668 - INFO - 2362377068 - Starting fine-tuning for MobileNetV3-Small...
2025-06-01 20:36:50,670 - INFO - 1892375343 - No checkpoint found at ./models/checkpoints/MobileNetV3-Small/checkpoint_latest__Refactored..._samples8000_epochs2.pth, starting from scratch.
2025-06-01 20:36:50,685 - INFO - 2362377068 - Epoch 1/2
2025-06-01 20:36:50,686 - INFO - 2362377068 - Using gradient accumulation with 2 steps


Training:   0%|          | 0/375 [00:00<?, ?it/s]

2025-06-01 20:37:44,036 - INFO - 2362377068 - Train Loss: 0.2678, Train Accuracy: 89.23%


Validation:   0%|          | 0/75 [00:00<?, ?it/s]

2025-06-01 20:37:49,966 - INFO - 2362377068 - Val Loss: 0.1866, Val Accuracy: 92.58%
2025-06-01 20:37:50,018 - INFO - 2362377068 - New best model found for MobileNetV3-Small! (Val Accuracy: 92.58%)
2025-06-01 20:37:50,083 - INFO - 1892375343 - Latest checkpoint saved at ./models/checkpoints/MobileNetV3-Small/checkpoint_latest__Refactored..._samples8000_epochs2.pth
2025-06-01 20:37:50,106 - INFO - 1892375343 - Best model saved at ./models/MobileNetV3-Small_best__Refactored..._samples8000_epochs2.pth
2025-06-01 20:37:50,159 - INFO - 2362377068 - Current learning rate: 0.000500
2025-06-01 20:37:50,159 - INFO - 2362377068 - Epoch 2/2
2025-06-01 20:37:50,161 - INFO - 2362377068 - Using gradient accumulation with 2 steps


Training:   0%|          | 0/375 [00:00<?, ?it/s]

2025-06-01 20:38:42,800 - INFO - 2362377068 - Train Loss: 0.1986, Train Accuracy: 92.05%


Validation:   0%|          | 0/75 [00:00<?, ?it/s]

2025-06-01 20:38:48,072 - INFO - 2362377068 - Val Loss: 0.1702, Val Accuracy: 93.25%
2025-06-01 20:38:48,127 - INFO - 2362377068 - New best model found for MobileNetV3-Small! (Val Accuracy: 93.25%)
2025-06-01 20:38:48,232 - INFO - 1892375343 - Latest checkpoint saved at ./models/checkpoints/MobileNetV3-Small/checkpoint_latest__Refactored..._samples8000_epochs2.pth
2025-06-01 20:38:48,269 - INFO - 1892375343 - Best model saved at ./models/MobileNetV3-Small_best__Refactored..._samples8000_epochs2.pth
2025-06-01 20:38:48,352 - INFO - 2362377068 - Current learning rate: 0.000500
2025-06-01 20:38:48,393 - INFO - 2362377068 - Final model saved: ./models/MobileNetV3-Small_final__Refactored..._samples8000_epochs2.pth
2025-06-01 20:38:48,394 - INFO - 2362377068 - Total training time: 117.72s


Testing:   0%|          | 0/94 [00:00<?, ?it/s]

2025-06-01 20:38:54,994 - INFO - 2362377068 - Test Loss: 0.1837, Test Accuracy: 92.65%
2025-06-01 20:38:54,996 - INFO - 2362377068 - --- Enhanced Fine-tuning Model: ShuffleNetV2_x0_5 ---
2025-06-01 20:38:55,057 - INFO - 2362377068 - Starting fine-tuning for ShuffleNetV2_x0_5...
2025-06-01 20:38:55,060 - INFO - 1892375343 - No checkpoint found at ./models/checkpoints/ShuffleNetV2_x0_5/checkpoint_latest__Refactored..._samples8000_epochs2.pth, starting from scratch.
2025-06-01 20:38:55,078 - INFO - 2362377068 - Epoch 1/2
2025-06-01 20:38:55,080 - INFO - 2362377068 - Using gradient accumulation with 2 steps


Training:   0%|          | 0/375 [00:00<?, ?it/s]

2025-06-01 20:39:47,588 - INFO - 2362377068 - Train Loss: 0.3039, Train Accuracy: 87.66%


Validation:   0%|          | 0/75 [00:00<?, ?it/s]

2025-06-01 20:39:53,418 - INFO - 2362377068 - Val Loss: 0.2451, Val Accuracy: 90.54%
2025-06-01 20:39:53,463 - INFO - 2362377068 - New best model found for ShuffleNetV2_x0_5! (Val Accuracy: 90.54%)
2025-06-01 20:39:53,538 - INFO - 1892375343 - Latest checkpoint saved at ./models/checkpoints/ShuffleNetV2_x0_5/checkpoint_latest__Refactored..._samples8000_epochs2.pth
2025-06-01 20:39:53,564 - INFO - 1892375343 - Best model saved at ./models/ShuffleNetV2_x0_5_best__Refactored..._samples8000_epochs2.pth
2025-06-01 20:39:53,628 - INFO - 2362377068 - Current learning rate: 0.000500
2025-06-01 20:39:53,629 - INFO - 2362377068 - Epoch 2/2
2025-06-01 20:39:53,631 - INFO - 2362377068 - Using gradient accumulation with 2 steps


Training:   0%|          | 0/375 [00:00<?, ?it/s]

2025-06-01 20:40:47,894 - INFO - 2362377068 - Train Loss: 0.2198, Train Accuracy: 91.14%


Validation:   0%|          | 0/75 [00:00<?, ?it/s]

2025-06-01 20:40:53,272 - INFO - 2362377068 - Val Loss: 0.1926, Val Accuracy: 92.15%
2025-06-01 20:40:53,316 - INFO - 2362377068 - New best model found for ShuffleNetV2_x0_5! (Val Accuracy: 92.15%)
2025-06-01 20:40:53,423 - INFO - 1892375343 - Latest checkpoint saved at ./models/checkpoints/ShuffleNetV2_x0_5/checkpoint_latest__Refactored..._samples8000_epochs2.pth
2025-06-01 20:40:53,462 - INFO - 1892375343 - Best model saved at ./models/ShuffleNetV2_x0_5_best__Refactored..._samples8000_epochs2.pth
2025-06-01 20:40:53,556 - INFO - 2362377068 - Current learning rate: 0.000500
2025-06-01 20:40:53,611 - INFO - 2362377068 - Final model saved: ./models/ShuffleNetV2_x0_5_final__Refactored..._samples8000_epochs2.pth
2025-06-01 20:40:53,612 - INFO - 2362377068 - Total training time: 118.55s


Testing:   0%|          | 0/94 [00:00<?, ?it/s]

2025-06-01 20:41:00,439 - INFO - 2362377068 - Test Loss: 0.2074, Test Accuracy: 91.50%
2025-06-01 20:41:00,441 - INFO - 2362377068 - --- Enhanced Fine-tuning Model: SqueezeNet1_1 ---
2025-06-01 20:41:00,477 - INFO - 2362377068 - Enabling gradient checkpointing for SqueezeNet1_1
2025-06-01 20:41:00,478 - INFO - 2362377068 - Gradient checkpointing not available for SqueezeNet1_1 with Sequential features
2025-06-01 20:41:00,479 - INFO - 2362377068 - Starting fine-tuning for SqueezeNet1_1...
2025-06-01 20:41:00,482 - INFO - 1892375343 - No checkpoint found at ./models/checkpoints/SqueezeNet1_1/checkpoint_latest__Refactored..._samples8000_epochs2.pth, starting from scratch.
2025-06-01 20:41:00,487 - INFO - 2362377068 - Epoch 1/2
2025-06-01 20:41:00,489 - INFO - 2362377068 - Using gradient accumulation with 2 steps


Training:   0%|          | 0/375 [00:00<?, ?it/s]

2025-06-01 20:41:55,662 - INFO - 2362377068 - Train Loss: 0.5934, Train Accuracy: 72.20%


Validation:   0%|          | 0/75 [00:00<?, ?it/s]

2025-06-01 20:42:01,738 - INFO - 2362377068 - Val Loss: 0.4677, Val Accuracy: 80.00%
2025-06-01 20:42:01,776 - INFO - 2362377068 - New best model found for SqueezeNet1_1! (Val Accuracy: 80.00%)
2025-06-01 20:42:01,809 - INFO - 1892375343 - Latest checkpoint saved at ./models/checkpoints/SqueezeNet1_1/checkpoint_latest__Refactored..._samples8000_epochs2.pth
2025-06-01 20:42:01,818 - INFO - 1892375343 - Best model saved at ./models/SqueezeNet1_1_best__Refactored..._samples8000_epochs2.pth
2025-06-01 20:42:01,845 - INFO - 2362377068 - Current learning rate: 0.000500
2025-06-01 20:42:01,846 - INFO - 2362377068 - Epoch 2/2
2025-06-01 20:42:01,847 - INFO - 2362377068 - Using gradient accumulation with 2 steps


Training:   0%|          | 0/375 [00:00<?, ?it/s]

2025-06-01 20:42:56,879 - INFO - 2362377068 - Train Loss: 0.3338, Train Accuracy: 87.06%


Validation:   0%|          | 0/75 [00:00<?, ?it/s]

2025-06-01 20:43:02,194 - INFO - 2362377068 - Val Loss: 0.2574, Val Accuracy: 89.38%
2025-06-01 20:43:02,235 - INFO - 2362377068 - New best model found for SqueezeNet1_1! (Val Accuracy: 89.38%)
2025-06-01 20:43:02,280 - INFO - 1892375343 - Latest checkpoint saved at ./models/checkpoints/SqueezeNet1_1/checkpoint_latest__Refactored..._samples8000_epochs2.pth
2025-06-01 20:43:02,295 - INFO - 1892375343 - Best model saved at ./models/SqueezeNet1_1_best__Refactored..._samples8000_epochs2.pth
2025-06-01 20:43:02,335 - INFO - 2362377068 - Current learning rate: 0.000500
2025-06-01 20:43:02,350 - INFO - 2362377068 - Final model saved: ./models/SqueezeNet1_1_final__Refactored..._samples8000_epochs2.pth
2025-06-01 20:43:02,350 - INFO - 2362377068 - Total training time: 121.87s


Testing:   0%|          | 0/94 [00:00<?, ?it/s]

2025-06-01 20:43:08,917 - INFO - 2362377068 - Test Loss: 0.2787, Test Accuracy: 88.92%
2025-06-01 20:43:08,919 - INFO - 2362377068 - 
--- Enhanced Fine-tuning Benchmarking Results ---
2025-06-01 20:43:08,919 - INFO - 2362377068 - Model: MobileNetV3-Small, Test Loss: 0.1837, Test Accuracy: 92.65%, Params: 2.54M, Train Time: 117.72, Inference Time: 6.6003, Inference Speed: 909.04
2025-06-01 20:43:08,920 - INFO - 2362377068 - Model: ShuffleNetV2_x0_5, Test Loss: 0.2074, Test Accuracy: 91.50%, Params: 1.37M, Train Time: 118.55, Inference Time: 6.8264, Inference Speed: 878.94
2025-06-01 20:43:08,921 - INFO - 2362377068 - Model: SqueezeNet1_1, Test Loss: 0.2787, Test Accuracy: 88.92%, Params: 1.24M, Train Time: 121.87, Inference Time: 6.5658, Inference Speed: 913.82
2025-06-01 20:43:08,922 - INFO - 2362377068 - Best model(s) based on test accuracy:
2025-06-01 20:43:08,922 - INFO - 2362377068 -  - MobileNetV3-Small: 92.65%
2025-06-01 20:43:08,923 - INFO - 2362377068 -  - ShuffleNetV2_x0_5: 91

[{'model_name': 'MobileNetV3-Small',
  'test_loss': 0.1836968141781787,
  'test_accuracy': 92.65,
  'params': 2.542856,
  'train_time': '117.72',
  'inference_time': '6.6003',
  'inference_speed': '909.04',
  'best_val_accuracy': 93.25},
 {'model_name': 'ShuffleNetV2_x0_5',
  'test_loss': 0.2073926074070235,
  'test_accuracy': 91.5,
  'params': 1.366792,
  'train_time': '118.55',
  'inference_time': '6.8264',
  'inference_speed': '878.94',
  'best_val_accuracy': 92.14583333333333},
 {'model_name': 'SqueezeNet1_1',
  'test_loss': 0.2786646982928117,
  'test_accuracy': 88.91666666666667,
  'params': 1.235496,
  'train_time': '121.87',
  'inference_time': '6.5658',
  'inference_speed': '913.82',
  'best_val_accuracy': 89.375}]

### Save Models with Classes

In [None]:

# Create a mapping from indices to class names for inference
IDX_TO_CLASS = {i: category for i, category in enumerate(QUICKDRAW_CATEGORIES)}

# Inference wrapper that includes class names
class ClassNameInferenceWrapper:
    def __init__(self, model, idx_to_class=None):
        self.model = model
        # Include QUICKDRAW_CATEGORIES for completeness
        self.QUICKDRAW_CATEGORIES = QUICKDRAW_CATEGORIES
        self.idx_to_class = idx_to_class or IDX_TO_CLASS
        self.model.eval()  # Set to evaluation mode

    def predict(self, inputs):
        with torch.no_grad():
            outputs = self.model(inputs)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)
            _, predicted_indices = torch.max(outputs, 1)

            # Convert to numpy for easier handling
            predicted_indices = predicted_indices.cpu().numpy()
            probabilities = probabilities.cpu().numpy()

            # Map indices to class names
            predicted_names = [self.idx_to_class.get(idx, "Unknown") for idx in predicted_indices]

            return {
                'class_idx': predicted_indices,
                'class_name': predicted_names,
                'probabilities': probabilities
            }

    def predict_single(self, input_tensor):
        # Add batch dimension if needed
        if input_tensor.dim() == 3:
            input_tensor = input_tensor.unsqueeze(0)

        result = self.predict(input_tensor)

        # Return just the first result since it's a single image
        return {
            'class_idx': result['class_idx'][0],
            'class_name': result['class_name'][0],
            'probabilities': result['probabilities'][0]
        }

def save_model_with_classes(model, model_path, class_names):
    """Save model weights along with class information for portability"""
    # Save model weights
    torch.save(model.state_dict(), model_path)

    # Save class information
    class_info_path = model_path.replace('.pth', '_classes.json')
    with open(class_info_path, 'w') as f:
        json.dump(class_names, f)

    logger.info(f"Model saved to {model_path}")
    logger.info(f"Class information saved to {class_info_path}")

    return model_path, class_info_path

### Class-wise Evaluation Metrics

In [None]:
def evaluate_model_by_class(model, test_loader, class_names):
    """Evaluate model performance for each class separately"""
    model.eval()
    class_correct = [0] * len(class_names)
    class_total = [0] * len(class_names)

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating by class"):
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)

            # Calculate class-wise accuracy
            for i in range(len(labels)):
                label = labels[i].item()
                class_correct[label] += (predicted[i] == label).item()
                class_total[label] += 1

    # logger.info and return results
    logger.info("\nClass-wise Accuracy:")
    class_accuracies = {}
    for i in range(len(class_names)):
        if class_total[i] > 0:
            accuracy = 100 * class_correct[i] / class_total[i]
            logger.info(f"{class_names[i]}: {accuracy:.2f}%")
            class_accuracies[class_names[i]] = accuracy

    return class_accuracies

### Inference

In [None]:
# Load a test image for inference
# Let's get a sample from the test dataset
_, _, test_dataset = get_augmented_quickdraw_data(
    QUICKDRAW_CATEGORIES[:5],  # Using just a few categories for faster loading
    10,  # Small number of samples per category
    5,
    BINARY_DATA_ROOT
)

# Get a batch of test images
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=True)
images, labels = next(iter(test_loader))

# Use the class name wrapper for inference
# Example with a saved model:
model = models.mobilenet_v3_small(weights=None)
# Set up classifier for the number of classes we have
in_features = model.classifier[0].in_features
model.classifier = nn.Sequential(
    nn.Linear(in_features, 1024),
    nn.BatchNorm1d(1024),
    nn.ReLU(inplace=True),
    nn.Dropout(0.4),
    nn.Linear(1024, 512),
    nn.BatchNorm1d(512),
    nn.ReLU(inplace=True),
    nn.Dropout(0.3),
    nn.Linear(512, len(QUICKDRAW_CATEGORIES))
)

# Try to load the model if it exists
try:
    model.load_state_dict(torch.load(f"{MODEL_SAVE_PATH}/MobileNetV3-Small_best{LATEST_PATH}.pth"))
    logger.info(f"Successfully loaded model from {MODEL_SAVE_PATH}/MobileNetV3-Small_best{LATEST_PATH}.pth")
except Exception as e:
    logger.info(f"Could not load model: {e}. Using untrained model for demonstration.")

model.to(DEVICE)

# Wrap the model
inference_wrapper = ClassNameInferenceWrapper(model)

# Move images to device
images = images.to(DEVICE)

# Make predictions with class names
result = inference_wrapper.predict(images)
logger.info(f"Predicted classes: {result['class_name']}")
logger.info(f"True labels: {[QUICKDRAW_CATEGORIES[label.item()] for label in labels]}")

2025-06-01 20:43:08,958 - INFO - 3059701004 - Loading augmented QuickDraw data for 3 categories...
2025-06-01 20:43:08,960 - INFO - 3059701004 - Loading QuickDraw category: apple from local binary files...
2025-06-01 20:43:08,961 - INFO - 4131335437 - Successfully loaded or indexed 144722 drawings for apple. Cache capacity: 50000 items.
2025-06-01 20:43:08,961 - INFO - 3059701004 - Category apple: 10 train, 2 val, 2000 test
2025-06-01 20:43:09,013 - INFO - 3059701004 - Loading QuickDraw category: cat from local binary files...
2025-06-01 20:43:09,013 - INFO - 4131335437 - Successfully loaded or indexed 123202 drawings for cat. Cache capacity: 50000 items.
2025-06-01 20:43:09,014 - INFO - 3059701004 - Category cat: 10 train, 2 val, 2000 test
2025-06-01 20:43:09,060 - INFO - 3059701004 - Loading QuickDraw category: dog from local binary files...
2025-06-01 20:43:09,061 - INFO - 4131335437 - Successfully loaded or indexed 152159 drawings for dog. Cache capacity: 50000 items.
2025-06-01 20

In [None]:
def save_sample_images(images, labels, predictions, class_names, save_dir="./sample_images"):
    """Save sample images with their true and predicted labels"""
    # Create directory if it doesn't exist
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # Convert tensor images to PIL for saving
    for i in range(min(5, len(images))):
        # Convert tensor to PIL image
        img_tensor = images[i].cpu()
        img = T.ToPILImage()(img_tensor)

        true_label = class_names[labels[i].item()]
        pred_label = predictions['class_name'][i]

        # Save image with informative filename
        filename = f"{save_dir}/sample_{i+1}_{true_label}_pred_{pred_label}.png"
        img.save(filename)
        logger.info(f"Saved sample image to {filename}")

### Trying Ensemble of Models

In [None]:
import torch.nn.functional as F

class EnhancedEnsembleModel(nn.Module):
    def __init__(self, models, weights=None, device=DEVICE, method='average'):
        super().__init__()
        self.models = models
        # Initialize with equal weights if none provided
        self.weights = weights if weights is not None else [1.0/len(models)] * len(models)
        self.device = device
        self.method = method  # 'average' or 'stack'

        # Register weights as a parameter so they can be optimized
        self.learned_weights = nn.Parameter(torch.tensor(self.weights, device=device))

        # For stacking method, add a meta-classifier
        if self.method == 'stack':
            # Input size: number of classes * number of models
            # Each model produces probabilities for each class
            input_size = len(QUICKDRAW_CATEGORIES) * len(models)
            self.meta_classifier = nn.Sequential(
                nn.Linear(input_size, 512),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(512, len(QUICKDRAW_CATEGORIES))
            )

    def to(self, device):
        """Properly move all components to the specified device"""
        self.device = device
        for model in self.models:
            model.to(device)
        if hasattr(self, 'meta_classifier'):
            self.meta_classifier.to(device)
        return super().to(device)  # This moves the learned_weights parameter

    def eval(self):
        for model in self.models:
            model.eval()
        super().eval()

    def train(self, mode=True):
        for model in self.models:
            model.train(mode)
        super().train(mode)

    def forward(self, x):
        if self.method == 'average':
            return self._forward_average(x)
        elif self.method == 'stack':
            return self._forward_stack(x)
        else:
            raise ValueError(f"Unknown ensemble method: {self.method}")

    def _forward_average(self, x):
        """Weighted averaging of model outputs"""
        # Apply softmax to learned weights
        normalized_weights = F.softmax(self.learned_weights, dim=0)

        # Apply torch.set_grad_enabled based on training mode
        with torch.set_grad_enabled(self.training):
            outputs = []
            for i, model in enumerate(self.models):
                output = model(x)
                # Apply softmax to get probabilities
                probs = F.softmax(output, dim=1)
                # Apply weight for this model
                outputs.append(probs * normalized_weights[i])

            # Sum the weighted outputs
            combined = torch.stack(outputs).sum(dim=0)
            # Convert back to logits for compatibility with CrossEntropyLoss
            return torch.log(combined + 1e-8)

    def _forward_stack(self, x):
        """Stacking method - use a meta-classifier on concatenated model outputs"""
        all_probs = []

        # Get predictions from all models
        for model in self.models:
            with torch.set_grad_enabled(self.training):
                output = model(x)
                probs = F.softmax(output, dim=1)
                all_probs.append(probs)

        # Concatenate all probabilities into a single feature vector
        combined = torch.cat(all_probs, dim=1)

        # Ensure combined tensor is on the same device as the meta-classifier
        # This is the key fix for the device mismatch error
        combined = combined.to(self.device)

        # Feed through meta-classifier
        return self.meta_classifier(combined)

def export_ensemble_for_deployment(ensemble, model_path):
    """Create a portable ensemble model package with all required components"""
    # Get git info
    git_info = get_git_info()

    # Create a dictionary containing all necessary information
    export_data = {
        'state_dict': ensemble.state_dict(),
        'model_weights': [m.state_dict() for m in ensemble.models],
        'learned_weights': ensemble.learned_weights.detach().cpu().numpy().tolist(),
        'class_names': QUICKDRAW_CATEGORIES,
        'git_info': git_info,
        'samples_per_category': NUM_TRAIN_SAMPLES_PER_CATEGORY,
        'finetune_epochs': NUM_FINETUNE_EPOCHS
    }
    # Save to file
    torch.save(export_data, model_path)
    logger.info(f"Portable ensemble saved to {model_path}")

    return model_path



### Using Model Stack Ensembling Approach

In [None]:
def train_ensemble_weights(ensemble, epochs=5):
    """Train the ensemble weights using a small validation set"""
    # Load a small validation set
    _, val_dataset, _ = get_augmented_quickdraw_data(
        QUICKDRAW_CATEGORIES,
        NUM_TRAIN_SAMPLES_PER_CATEGORY // 10,  # Use a smaller subset
        NUM_TEST_SAMPLES_PER_CATEGORY // 5,
        BINARY_DATA_ROOT
    )

    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=True, num_workers=2)

    # Only train the ensemble weights, not the individual models
    for model in ensemble.models:
        for param in model.parameters():
            param.requires_grad = False

    # Set ensemble to training mode
    ensemble.train()

    # Use optimizer only for the learned weights
    optimizer = optim.Adam([ensemble.learned_weights], lr=0.01)
    criterion = nn.CrossEntropyLoss()

    logger.info(f"Training ensemble weights for {epochs} epochs...")
    for epoch in range(epochs):
        running_loss = 0.0
        correct = 0
        total = 0

        for images, labels in tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            optimizer.zero_grad()
            outputs = ensemble(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

        # logger.info epoch statistics
        epoch_loss = running_loss / total
        epoch_acc = 100 * correct / total
        logger.info(f"Epoch {epoch+1}/{epochs} - Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.2f}%")

        # logger.info current weights
        normalized_weights = F.softmax(ensemble.learned_weights, dim=0).cpu().detach().numpy()
        weight_str = ", ".join([f"{w:.4f}" for w in normalized_weights])
        logger.info(f"Current weights: [{weight_str}]")

    # Switch back to evaluation mode
    ensemble.eval()
    return ensemble

### Ensemble Evaluation

In [None]:
def load_model_wrapper(wrapper_class, model_name):
    """Helper to load a model with proper error handling"""
    try:
        model_path = f"{MODEL_SAVE_PATH}/{model_name}_best{LATEST_PATH}.pth"
        wrapper = wrapper_class()
        wrapper.model.load_state_dict(torch.load(model_path))
        wrapper.model.eval()  # Set to evaluation mode
        logger.info(f"Successfully loaded {model_name} from {model_path}")
        return wrapper
    except Exception as e:
        logger.info(f"Failed to load {model_name}: {e}")
        return None


def evaluate_ensemble(method='average', test_loader=None):
    logger.info(f"Evaluating ensemble using {method} method...")

    # Load the models using the wrappers
    try:

        # Explicitly flush handlers after a significant phase
        for handler in logger.handlers:
            if isinstance(handler, logging.FileHandler):
                handler.flush()

        # Use the helper function to load models
        mobilenet_wrapper = load_model_wrapper(
            MobileNetV3Wrapper, "MobileNetV3-Small")
        squeezenet_wrapper = load_model_wrapper(
            SqueezeNetWrapper, "SqueezeNet1_1")
        shufflenet_wrapper = load_model_wrapper(
            ShuffleNetV2Wrapper, "ShuffleNetV2_x0_5")

        models = []
        initial_weights = []

        # Add models that were successfully loaded
        if mobilenet_wrapper and squeezenet_wrapper:
            models.extend([mobilenet_wrapper.model, squeezenet_wrapper.model])
            initial_weights.extend([0.4, 0.3])

            # Try to add ShuffleNet if available
            if shufflenet_wrapper:
                models.append(shufflenet_wrapper.model)
                initial_weights.append(0.3)
                logger.info(
                    f"Using 3-model ensemble ({method} method) with MobileNet, SqueezeNet, ShuffleNet")
            else:
                # Rebalance weights for 2-model ensemble
                initial_weights = [0.6, 0.4]
                logger.info("Using 2-model ensemble (MobileNet, SqueezeNet)")
        else:
            raise ValueError("Could not load enough models for ensemble")

        # Create enhanced ensemble with the available models
        ensemble = EnhancedEnsembleModel(
            models, weights=initial_weights, method=method)

    except Exception as e:
        logger.info(f"Could not load models: {e}. Cannot create ensemble.")
        return None

    # Move ensemble to device
    ensemble.to(DEVICE)
    ensemble.eval()

    # Load test data if not provided
    if test_loader is None:
        _, _, test_dataset = get_augmented_quickdraw_data(
            QUICKDRAW_CATEGORIES,
            NUM_TRAIN_SAMPLES_PER_CATEGORY,
            NUM_TEST_SAMPLES_PER_CATEGORY,
            BINARY_DATA_ROOT
        )

        test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4, pin_memory=True,
                                 prefetch_factor=3)

    # If using stacking method, train the meta-classifier first
    if method == 'stack':
        logger.info("Training meta-classifier for stacking ensemble...")
        # Create a small dataset for training the meta-classifier
        _, val_dataset, _ = get_augmented_quickdraw_data(
            QUICKDRAW_CATEGORIES,
            NUM_TRAIN_SAMPLES_PER_CATEGORY // 5,  # Use a smaller subset
            NUM_TEST_SAMPLES_PER_CATEGORY // 5,
            BINARY_DATA_ROOT
        )

        val_loader = DataLoader(
            val_dataset, batch_size=64, shuffle=True, num_workers=2,
            prefetch_factor=3)

        # Freeze base models
        for model in ensemble.models:
            for param in model.parameters():
                param.requires_grad = False

        # Only train meta-classifier
        optimizer = optim.Adam(ensemble.meta_classifier.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss()

        # Train for a few epochs
        ensemble.train()
        for epoch in range(5):

            # Explicitly flush handlers after a significant phase
            for handler in logger.handlers:
                if isinstance(handler, logging.FileHandler):
                    handler.flush()

            running_loss = 0.0
            correct = 0
            total = 0

            for images, labels in tqdm(val_loader, desc=f"Meta-classifier epoch {epoch+1}/5"):
                images, labels = images.to(DEVICE), labels.to(DEVICE)

                optimizer.zero_grad()
                outputs = ensemble(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item() * images.size(0)
                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)

            epoch_loss = running_loss / total
            epoch_acc = 100 * correct / total
            logger.info(f"Meta-classifier Epoch {epoch+1}/5 - Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.2f}%")

        ensemble.eval()
        clear_gpu_memory()  # Add explicit memory cleanup after training

    # For the averaging method, train the weights
    elif method == 'average':
        logger.info("Training ensemble weights for averaging method...")
        ensemble = train_ensemble_weights(ensemble, epochs=3)
        clear_gpu_memory()  # Add explicit memory cleanup after training

    # Evaluate
    model_correct = {
        "mobilenet": 0,
        "squeezenet": 0,
        "shufflenet": 0 if len(models) > 2 else None,
        "ensemble": 0
    }
    total = 0

    logger.info("Evaluating ensemble vs individual models...")
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating models"):
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            # Get predictions from individual models
            mobilenet_outputs = mobilenet_wrapper.model(images)
            squeezenet_outputs = squeezenet_wrapper.model(images)

            # Get ensemble prediction
            ensemble_outputs = ensemble(images)

            # Calculate accuracy for each model
            _, mobilenet_preds = torch.max(mobilenet_outputs, 1)
            _, squeezenet_preds = torch.max(squeezenet_outputs, 1)
            _, ensemble_preds = torch.max(ensemble_outputs, 1)

            # Update correct counts
            model_correct["mobilenet"] += (mobilenet_preds ==
                                           labels).sum().item()
            model_correct["squeezenet"] += (squeezenet_preds ==
                                            labels).sum().item()

            # Only evaluate ShuffleNet if it's part of the ensemble
            if len(models) > 2 and shufflenet_wrapper:
                shufflenet_outputs = shufflenet_wrapper.model(images)
                _, shufflenet_preds = torch.max(shufflenet_outputs, 1)
                model_correct["shufflenet"] += (shufflenet_preds ==
                                                labels).sum().item()

            model_correct["ensemble"] += (ensemble_preds ==
                                          labels).sum().item()

            total += labels.size(0)

    # Calculate and logger.info accuracies
    logger.info(f"\n--- Model Accuracy Comparison ({method} method) ---")
    mobilenet_acc = 100 * model_correct["mobilenet"] / total
    logger.info(f"MobileNetV3:  {mobilenet_acc:.2f}%")

    squeezenet_acc = 100 * model_correct["squeezenet"] / total
    logger.info(f"SqueezeNet:   {squeezenet_acc:.2f}%")

    # Only logger.info ShuffleNet accuracy if it's part of the ensemble
    if model_correct["shufflenet"] is not None:
        shufflenet_acc = 100 * model_correct["shufflenet"] / total
        logger.info(f"ShuffleNetV2: {shufflenet_acc:.2f}%")
        best_single = max(mobilenet_acc, squeezenet_acc, shufflenet_acc)
    else:
        best_single = max(mobilenet_acc, squeezenet_acc)

    ensemble_acc = 100 * model_correct["ensemble"] / total
    logger.info(f"Ensemble ({method}): {ensemble_acc:.2f}%")

    # Calculate improvement
    improvement = ensemble_acc - best_single
    logger.info(
        f"\nEnsemble improves accuracy by {improvement:.2f}% over the best single model")

    # For averaging method, logger.info the learned weights
    if method == 'average':
        normalized_weights = F.softmax(
            ensemble.learned_weights, dim=0).cpu().detach().numpy()
        logger.info("\nLearned model weights in ensemble:")
        logger.info(f"MobileNetV3:  {normalized_weights[0]:.4f}")
        logger.info(f"SqueezeNet:   {normalized_weights[1]:.4f}")
        if len(models) > 2:
            logger.info(f"ShuffleNetV2: {normalized_weights[2]:.4f}")

    # Save the ensemble model with error handling
    try:
        ensemble_path = f"{MODEL_SAVE_PATH}/ensemble_model_{method}{LATEST_PATH}.pth"

        # Use the export function for a fully portable model with error handling
        try:
            export_path = export_ensemble_for_deployment(ensemble, ensemble_path)
            logger.info(f"Ensemble model exported to {export_path}")
        except Exception as e:
            logger.error(f"Failed to export ensemble: {e}")
            # Try a simpler export approach
            torch.save(ensemble.state_dict(), ensemble_path)
            logger.info(f"Saved ensemble state dict to {ensemble_path} (fallback method)")

        # Also evaluate class-wise performance
        logger.info("\nEvaluating class-wise performance...")
        class_accuracies = evaluate_model_by_class(
            ensemble, test_loader, QUICKDRAW_CATEGORIES)

        # Save class accuracies with compression
        class_acc_path = f"{MODEL_SAVE_PATH}/ensemble_{method}_class_accuracies{LATEST_PATH}.json.gz"
        with gzip.open(class_acc_path, 'wt') as f:
            json.dump(class_accuracies, f, indent=2)
        logger.info(f"Class accuracies saved to {class_acc_path}")
    except Exception as e:
        logger.info(f"Could not save ensemble model: {e}")

    # Format results for saving
    ensemble_results = [
        {
            'model_name': 'MobileNetV3',
            'accuracy': mobilenet_acc,
            'ensemble_method': method
        },
        {
            'model_name': 'SqueezeNet',
            'accuracy': squeezenet_acc,
            'ensemble_method': method
        },
        {
            'model_name': f'Ensemble ({method})',
            'accuracy': ensemble_acc,
            'improvement_over_best': improvement,
            'ensemble_method': method
        }
    ]

    # Add ShuffleNet if it was used
    if model_correct["shufflenet"] is not None:
        ensemble_results.insert(2, {
            'model_name': 'ShuffleNetV2',
            'accuracy': shufflenet_acc,
            'ensemble_method': method
        })

    # Save results
    save_benchmark_results(ensemble_results, benchmark_type=f'ensemble_{method}')
    # Return ensemble for later use
    return ensemble

In [None]:
# Run ensemble evaluation with averaging method
logger.info("="*50)
logger.info("EVALUATING ENSEMBLE WITH AVERAGING METHOD")
logger.info("="*50)
averaging_ensemble = evaluate_ensemble(method='average')

# Run ensemble evaluation with stacking method
logger.info("="*50)
logger.info("EVALUATING ENSEMBLE WITH STACKING METHOD")
logger.info("="*50)
stacking_ensemble = evaluate_ensemble(method='stack')

# Compare the results
logger.info("="*50)
logger.info("ENSEMBLE METHOD COMPARISON")
logger.info("="*50)
logger.info("If both methods ran successfully, you can compare their performance.")
logger.info("The stacking method usually performs better when there are complementary strengths")
logger.info("in the base models, while averaging is more robust to overfitting.")

# Save sample inference images
_, _, test_dataset = get_augmented_quickdraw_data(
    QUICKDRAW_CATEGORIES[:10],  # Using just first 10 categories for faster loading
    10,  # Small number of samples per category
    5,
    BINARY_DATA_ROOT
)

# Create a small test loader
test_loader = DataLoader(test_dataset, batch_size=5, shuffle=True)

# Get a batch of test images
try:
    images, labels = next(iter(test_loader))
    images = images.to(DEVICE)

    # Create inference wrapper for the best ensemble
    best_ensemble = averaging_ensemble  # You can change this to stacking_ensemble if it performs better
    if best_ensemble is not None:
        inference_wrapper = ClassNameInferenceWrapper(best_ensemble, IDX_TO_CLASS)

        # Make predictions
        predictions = inference_wrapper.predict(images)

        # Save sample images
        save_sample_images(
            images,
            labels,
            predictions,
            QUICKDRAW_CATEGORIES,
            save_dir="./sample_ensemble_predictions"
        )
except Exception as e:
    logger.info(f"Error saving sample images: {e}")


# Explicitly flush handlers after a significant phase
for handler in logger.handlers:
    if isinstance(handler, logging.FileHandler):
        handler.flush()


2025-06-01 20:43:09,331 - INFO - 2882100689 - EVALUATING ENSEMBLE WITH AVERAGING METHOD
2025-06-01 20:43:09,332 - INFO - 2657814370 - Evaluating ensemble using average method...
2025-06-01 20:43:09,430 - INFO - 2657814370 - Successfully loaded MobileNetV3-Small from ./models/MobileNetV3-Small_best__Refactored..._samples8000_epochs2.pth
2025-06-01 20:43:09,459 - INFO - 2657814370 - Successfully loaded SqueezeNet1_1 from ./models/SqueezeNet1_1_best__Refactored..._samples8000_epochs2.pth
2025-06-01 20:43:09,537 - INFO - 2657814370 - Successfully loaded ShuffleNetV2_x0_5 from ./models/ShuffleNetV2_x0_5_best__Refactored..._samples8000_epochs2.pth
2025-06-01 20:43:09,538 - INFO - 2657814370 - Using 3-model ensemble (average method) with MobileNet, SqueezeNet, ShuffleNet
2025-06-01 20:43:09,561 - INFO - 3059701004 - Loading augmented QuickDraw data for 3 categories...
2025-06-01 20:43:09,561 - INFO - 3059701004 - Loading QuickDraw category: apple from local binary files...
2025-06-01 20:43:09

Epoch 1/3:   0%|          | 0/8 [00:00<?, ?it/s]

2025-06-01 20:43:10,972 - INFO - 2433811703 - Epoch 1/3 - Loss: 0.1871, Acc: 92.29%
2025-06-01 20:43:10,973 - INFO - 2433811703 - Current weights: [0.3821, 0.2996, 0.3182]


Epoch 2/3:   0%|          | 0/8 [00:00<?, ?it/s]

2025-06-01 20:43:11,991 - INFO - 2433811703 - Epoch 2/3 - Loss: 0.1922, Acc: 93.12%
2025-06-01 20:43:11,992 - INFO - 2433811703 - Current weights: [0.4057, 0.2790, 0.3153]


Epoch 3/3:   0%|          | 0/8 [00:00<?, ?it/s]

2025-06-01 20:43:13,006 - INFO - 2433811703 - Epoch 3/3 - Loss: 0.1836, Acc: 92.71%
2025-06-01 20:43:13,007 - INFO - 2433811703 - Current weights: [0.4325, 0.2611, 0.3064]
2025-06-01 20:43:13,021 - INFO - 2657814370 - Evaluating ensemble vs individual models...


Evaluating models:   0%|          | 0/94 [00:00<?, ?it/s]

2025-06-01 20:43:26,983 - INFO - 2657814370 - 
--- Model Accuracy Comparison (average method) ---
2025-06-01 20:43:26,984 - INFO - 2657814370 - MobileNetV3:  93.47%
2025-06-01 20:43:26,985 - INFO - 2657814370 - SqueezeNet:   89.02%
2025-06-01 20:43:26,986 - INFO - 2657814370 - ShuffleNetV2: 92.45%
2025-06-01 20:43:26,986 - INFO - 2657814370 - Ensemble (average): 93.43%
2025-06-01 20:43:26,987 - INFO - 2657814370 - 
Ensemble improves accuracy by -0.03% over the best single model
2025-06-01 20:43:26,988 - INFO - 2657814370 - 
Learned model weights in ensemble:
2025-06-01 20:43:26,988 - INFO - 2657814370 - MobileNetV3:  0.4325
2025-06-01 20:43:26,989 - INFO - 2657814370 - SqueezeNet:   0.2611
2025-06-01 20:43:26,989 - INFO - 2657814370 - ShuffleNetV2: 0.3064
2025-06-01 20:43:27,062 - INFO - 4195590199 - Portable ensemble saved to ./models/ensemble_model_average__Refactored..._samples8000_epochs2.pth
2025-06-01 20:43:27,064 - INFO - 2657814370 - Ensemble model exported to ./models/ensemble

Evaluating by class:   0%|          | 0/94 [00:00<?, ?it/s]

2025-06-01 20:43:35,429 - INFO - 1837987104 - 
Class-wise Accuracy:
2025-06-01 20:43:35,430 - INFO - 1837987104 - apple: 98.65%
2025-06-01 20:43:35,431 - INFO - 1837987104 - cat: 90.40%
2025-06-01 20:43:35,431 - INFO - 1837987104 - dog: 91.25%
2025-06-01 20:43:35,433 - INFO - 2657814370 - Class accuracies saved to ./models/ensemble_average_class_accuracies__Refactored..._samples8000_epochs2.json.gz
2025-06-01 20:43:35,434 - INFO - 392765423 - Benchmark results saved to ./results/ensemble_average__Refactored..._samples8000_epochs2.json and ./results/ensemble_average__Refactored..._samples8000_epochs2.csv
2025-06-01 20:43:35,436 - INFO - 2882100689 - EVALUATING ENSEMBLE WITH STACKING METHOD
2025-06-01 20:43:35,437 - INFO - 2657814370 - Evaluating ensemble using stack method...
2025-06-01 20:43:35,530 - INFO - 2657814370 - Successfully loaded MobileNetV3-Small from ./models/MobileNetV3-Small_best__Refactored..._samples8000_epochs2.pth
2025-06-01 20:43:35,559 - INFO - 2657814370 - Successf

Meta-classifier epoch 1/5:   0%|          | 0/15 [00:00<?, ?it/s]

2025-06-01 20:43:37,855 - INFO - 2657814370 - Meta-classifier Epoch 1/5 - Loss: 0.8134, Acc: 76.56%


Meta-classifier epoch 2/5:   0%|          | 0/15 [00:00<?, ?it/s]

2025-06-01 20:43:39,608 - INFO - 2657814370 - Meta-classifier Epoch 2/5 - Loss: 0.3777, Acc: 93.23%


Meta-classifier epoch 3/5:   0%|          | 0/15 [00:00<?, ?it/s]

2025-06-01 20:43:41,379 - INFO - 2657814370 - Meta-classifier Epoch 3/5 - Loss: 0.2385, Acc: 93.23%


Meta-classifier epoch 4/5:   0%|          | 0/15 [00:00<?, ?it/s]

2025-06-01 20:43:43,180 - INFO - 2657814370 - Meta-classifier Epoch 4/5 - Loss: 0.2018, Acc: 93.33%


Meta-classifier epoch 5/5:   0%|          | 0/15 [00:00<?, ?it/s]

2025-06-01 20:43:44,940 - INFO - 2657814370 - Meta-classifier Epoch 5/5 - Loss: 0.2006, Acc: 92.81%
2025-06-01 20:43:44,950 - INFO - 2657814370 - Evaluating ensemble vs individual models...


Evaluating models:   0%|          | 0/94 [00:00<?, ?it/s]

2025-06-01 20:43:58,829 - INFO - 2657814370 - 
--- Model Accuracy Comparison (stack method) ---
2025-06-01 20:43:58,830 - INFO - 2657814370 - MobileNetV3:  92.95%
2025-06-01 20:43:58,831 - INFO - 2657814370 - SqueezeNet:   88.92%
2025-06-01 20:43:58,831 - INFO - 2657814370 - ShuffleNetV2: 92.08%
2025-06-01 20:43:58,832 - INFO - 2657814370 - Ensemble (stack): 92.92%
2025-06-01 20:43:58,833 - INFO - 2657814370 - 
Ensemble improves accuracy by -0.03% over the best single model
2025-06-01 20:43:58,900 - INFO - 4195590199 - Portable ensemble saved to ./models/ensemble_model_stack__Refactored..._samples8000_epochs2.pth
2025-06-01 20:43:58,901 - INFO - 2657814370 - Ensemble model exported to ./models/ensemble_model_stack__Refactored..._samples8000_epochs2.pth
2025-06-01 20:43:58,902 - INFO - 2657814370 - 
Evaluating class-wise performance...


Evaluating by class:   0%|          | 0/94 [00:00<?, ?it/s]

2025-06-01 20:44:07,551 - INFO - 1837987104 - 
Class-wise Accuracy:
2025-06-01 20:44:07,552 - INFO - 1837987104 - apple: 97.85%
2025-06-01 20:44:07,552 - INFO - 1837987104 - cat: 89.80%
2025-06-01 20:44:07,553 - INFO - 1837987104 - dog: 91.10%
2025-06-01 20:44:07,555 - INFO - 2657814370 - Class accuracies saved to ./models/ensemble_stack_class_accuracies__Refactored..._samples8000_epochs2.json.gz
2025-06-01 20:44:07,556 - INFO - 392765423 - Benchmark results saved to ./results/ensemble_stack__Refactored..._samples8000_epochs2.json and ./results/ensemble_stack__Refactored..._samples8000_epochs2.csv
2025-06-01 20:44:07,557 - INFO - 2882100689 - ENSEMBLE METHOD COMPARISON
2025-06-01 20:44:07,559 - INFO - 2882100689 - If both methods ran successfully, you can compare their performance.
2025-06-01 20:44:07,559 - INFO - 2882100689 - The stacking method usually performs better when there are complementary strengths
2025-06-01 20:44:07,560 - INFO - 2882100689 - in the base models, while averag

# Task
Modify the `MultimodalEmbeddingWrapper` class to save generated image embeddings to a file named based on the number of categories and images per category. If the file exists, load embeddings from the file instead of regenerating them.

## Modify multimodalembeddingwrapper initialization

### Subtask:
Add attributes to the `MultimodalEmbeddingWrapper` class to store the number of categories and images per category used for generating embeddings.


**Reasoning**:
The subtask is to modify the `MultimodalEmbeddingWrapper` class to accept and store `num_categories` and `num_images_per_category` attributes. I will modify the `__init__` method of the class to include these arguments and store them as instance attributes.



**Reasoning**:
The previous command failed because `ModelWrapper` was not defined in the code block. I need to include the definition of `ModelWrapper` and the other model wrapper classes in this code block, as they are dependencies for `MultimodalEmbeddingWrapper`.



In [17]:
from google.cloud import aiplatform
from vertexai.vision_models import Image as VertexAiImage, MultiModalEmbeddingModel
from PIL import Image
from io import BytesIO
import pickle # Import pickle for saving/loading embeddings
import os # Import os for file path operations

class MultimodalEmbeddingWrapper(ModelWrapper):
    def __init__(self, project_id, location, num_classes=len(QUICKDRAW_CATEGORIES), categories=QUICKDRAW_CATEGORIES, num_categories=len(QUICKDRAW_CATEGORIES), num_images_per_category=NUM_TRAIN_SAMPLES_PER_CATEGORY):
        super().__init__(None, num_classes)
        self.project_id = project_id
        self.location = location
        self.categories = categories
        # Add the new attributes as per the plan
        self.num_categories = num_categories
        self.num_images_per_category = num_images_per_category


        # Initialize Vertex AI and load the model
        try:
            aiplatform.init(project=self.project_id, location=self.location)
            self.model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding@001")
            logger.info("Successfully initialized Vertex AI and loaded MultimodalEmbeddingModel.")
        except Exception as e:
            logger.error(f"Failed to initialize Vertex AI or load MultimodalEmbeddingModel: {e}")
            self.model = None # Ensure model is None if loading fails

        # Pre-compute text embeddings for categories
        self.category_text_embeddings = self._get_category_text_embeddings()

        # Add a linear layer to classify image embeddings based on text embeddings
        # The MultimodalEmbeddingModel outputs a fixed dimension (1408)
        if self.category_text_embeddings is not None and self.model is not None:
             # Compute the average text embedding to get the dimension
             avg_embedding = np.mean(self.category_text_embeddings, axis=0)
             input_dim = len(avg_embedding) # Should be 1408
             # Use cosine similarity for classification
             # The text embeddings act as the "weights" for classification
             # We don't need a separate linear layer for this approach
             self.classifier = None # No traditional classifier needed for this method
             logger.info(f"Initialized MultimodalEmbeddingWrapper for similarity-based classification.")
        else:
             self.classifier = None
             logger.warning("Could not initialize MultimodalEmbeddingWrapper for similarity-based classification.")


    def _get_category_text_embeddings(self):
        """Generates text embeddings for each category name."""
        if not self.model:
            logger.warning("MultimodalEmbeddingModel not loaded, cannot get text embeddings.")
            return None

        text_embeddings = []
        logger.info("Generating text embeddings for categories...")
        try:
            for category in tqdm(self.categories, desc="Generating text embeddings"):
                 # The context_text is optional but can improve quality for ambiguous terms
                 embeddings = self.model.get_embeddings(
                     contextual_text=category,
                     dimension=1408 # A required, fixed dimension size for this model
                 )
                 if embeddings and embeddings.text_embedding:
                     text_embeddings.append(embeddings.text_embedding)
                 else:
                     logger.warning(f"Could not get text embedding for category: {category}")

            if not text_embeddings:
                 logger.warning("No text embeddings were generated.")
                 return None

            return np.array(text_embeddings) # Convert to numpy array
        except Exception as e:
             logger.error(f"Error generating text embeddings: {e}")
             return None

    def calculate_cosine_similarity(self, vec1, vec2):
        """Calculates the cosine similarity between two vectors."""
        # Ensure inputs are numpy arrays for dot product and norm
        vec1 = np.asarray(vec1)
        vec2 = np.asarray(vec2)
        # Add a small epsilon to the denominator to avoid division by zero
        return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2) + 1e-8)

    def forward(self, images):
        """
        Generates image embeddings and classifies them using cosine similarity
        with pre-computed text embeddings. Checks for and loads saved embeddings first.
        """
        if self.model is None or self.category_text_embeddings is None:
            logger.error("MultimodalEmbeddingModel or text embeddings not initialized.")
            # Return dummy output
            return torch.zeros(images.size(0), self.num_classes, device=images.device)

        # Generate the expected embeddings filename
        embeddings_file = get_embedding_filename(self.num_categories, self.num_images_per_category)

        # Check if the embeddings file exists and load from it
        if os.path.exists(embeddings_file):
            logger.info(f"Loading image embeddings from {embeddings_file}")
            try:
                with open(embeddings_file, 'rb') as f:
                    all_image_embeddings = pickle.load(f)
                logger.info(f"Successfully loaded {len(all_image_embeddings)} image embeddings.")

                # Convert loaded embeddings to torch tensors and move to device
                all_image_embeddings_tensors = [torch.tensor(emb, dtype=torch.float32).to(images.device) for emb in all_image_embeddings]

                # Stack the individual image embedding tensors
                stacked_image_embeddings = torch.stack(all_image_embeddings_tensors)

                # Calculate similarity scores with text embeddings
                # Ensure text embeddings are also on the device for tensor operations
                text_embeddings_tensor = torch.tensor(self.category_text_embeddings, dtype=torch.float32).to(images.device)

                # Calculate cosine similarity using tensor operations for GPU acceleration
                # Normalize embeddings
                image_embeddings_normalized = F.normalize(stacked_image_embeddings, p=2, dim=1)
                text_embeddings_normalized = F.normalize(text_embeddings_tensor, p=2, dim=1)

                # Calculate cosine similarity (dot product of normalized vectors)
                # Result will be a tensor of shape (batch_size, num_classes)
                logits = torch.matmul(image_embeddings_normalized, text_embeddings_normalized.transpose(0, 1))

                # Crucially, ensure the batch size of logits matches the input batch size
                if logits.size(0) != images.size(0):
                     logger.error(f"Mismatch between loaded embeddings ({logits.size(0)}) and input batch size ({images.size(0)}). Returning dummy output.")
                     return torch.zeros(images.size(0), self.num_classes, device=images.device)


                return logits # Return logits calculated from loaded embeddings

            except Exception as e:
                logger.error(f"Error loading or processing embeddings from {embeddings_file}: {e}. Proceeding to generate embeddings.")
                # Continue to generate embeddings if loading fails

        # If file doesn't exist or loading failed, generate embeddings
        logger.info("Embeddings file not found or failed to load. Generating new image embeddings...")
        # The MultimodalEmbeddingModel expects VertexAiImage objects initialized with bytes
        # Convert torch tensor batch to list of PIL Images, then to bytes, then to VertexAiImage
        vertex_ai_images = []
        for img_tensor in images:
             # Permute dimensions from C, H, W to H, W, C for numpy conversion
             # Convert to numpy array and scale to 0-255
             # Ensure the tensor is on CPU before converting to numpy/PIL
             img_np = (img_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
             # Convert numpy array to PIL Image
             pil_image = Image.fromarray(img_np)
             # Convert PIL image to bytes
             buffer = BytesIO()
             # Assuming the input images are grayscale sketches, save as PNG
             pil_image.save(buffer, format="PNG")
             image_bytes = buffer.getvalue()
             # Create VertexAiImage object from bytes
             vertex_ai_images.append(VertexAiImage(image_bytes=image_bytes))


        # Get image embeddings from the MultimodalEmbeddingModel
        try:
            # Process images one by one as a workaround for the batch issue
            all_image_embeddings = []
            # Use tqdm for progress tracking
            for i in tqdm(range(len(vertex_ai_images)), desc="Generating image embeddings", leave=False): # Changed desc for clarity
                img_ai = vertex_ai_images[i]
                embeddings_response = self.model.get_embeddings(
                    image=img_ai, # Pass single VertexAiImage object
                    dimension=1408
                )

                if embeddings_response and embeddings_response.image_embedding:
                    # Store the numpy embedding
                    all_image_embeddings.append(embeddings_response.image_embedding)
                else:
                    logger.warning(f"Could not get image embedding for image at index {i}. Skipping.")
                    # Append a zero vector (numpy) to maintain batch size consistency
                    all_image_embeddings.append(np.zeros(1408)) # Append zero vector of dimension 1408


            if not all_image_embeddings:
                 logger.warning("No image embeddings were generated.")
                 # Return dummy output if no embeddings were generated
                 return torch.zeros(images.size(0), self.num_classes, device=images.device)

            # Save the generated embeddings to file
            try:
                with open(embeddings_file, 'wb') as f:
                    pickle.dump(all_image_embeddings, f)
                logger.info(f"Successfully saved generated image embeddings to {embeddings_file}")
            except Exception as e:
                logger.error(f"Error saving generated embeddings to {embeddings_file}: {e}")

            # Convert generated embeddings (which are numpy) to torch tensors and move to device for classification
            all_image_embeddings_tensors = [torch.tensor(emb, dtype=torch.float32).to(images.device) for emb in all_image_embeddings]


            # Stack the individual image embedding tensors
            stacked_image_embeddings = torch.stack(all_image_embeddings_tensors)

            # Calculate similarity scores for each image embedding with all text embeddings
            # Ensure text embeddings are also on the device for tensor operations
            text_embeddings_tensor = torch.tensor(self.category_text_embeddings, dtype=torch.float32).to(images.device)

            # Calculate cosine similarity using tensor operations for GPU acceleration
            # Normalize embeddings
            image_embeddings_normalized = F.normalize(stacked_image_embeddings, p=2, dim=1)
            text_embeddings_normalized = F.normalize(text_embeddings_tensor, p=2, dim=1)

            # Calculate cosine similarity (dot product of normalized vectors)
            # Result will be a tensor of shape (batch_size, num_classes)
            logits = torch.matmul(image_embeddings_normalized, text_embeddings_normalized.transpose(0, 1))

            # Crucially, ensure the batch size of logits matches the input batch size
            if logits.size(0) != images.size(0):
                 logger.error(f"Mismatch between generated embeddings ({logits.size(0)}) and input batch size ({images.size(0)}). Returning dummy output.")
                 return torch.zeros(images.size(0), self.num_classes, device=images.device)


            # Logits are already on the desired device from the matrix multiplication
            return logits

        except Exception as e:
            logger.error(f"Error generating image embeddings or classifying: {e}")
            # Return dummy output
            return torch.zeros(images.size(0), self.num_classes, device=images.device)

    # MultimodalEmbeddingModel does not have state_dict in the traditional PyTorch sense
    # Override state_dict and load_state_dict to handle the classifier (which is now None)
    # and text embeddings if we were to save/load them.
    def state_dict(self):
        # We can save the text embeddings if needed, but for this approach, they are
        # re-generated during initialization based on category names.
        # Saving the state_dict isn't really applicable for this type of model.
        # Returning an empty dict or raising an error might be appropriate.
        # For compatibility with the existing framework, return a minimal dict.
        return {}

    def load_state_dict(self, state_dict):
        # No state to load for this model using this method.
        logger.warning("load_state_dict is not applicable for MultimodalEmbeddingWrapper in similarity mode.")


    @classmethod
    def load_from_checkpoint(cls, model_path, project_id, location, device=DEVICE):
        """Loading from checkpoint is not applicable for this wrapper."""
        logger.warning("load_from_checkpoint is not applicable for MultimodalEmbeddingWrapper in similarity mode.")
        # Return a new instance initialized from scratch
        return cls(project_id=project_id, location=location)

In [16]:
import os

def get_embedding_filename(num_categories, num_images_per_category, base_dir="./embeddings"):
    """
    Generates a filename for the image embeddings file based on the number of categories
    and images per category. Ensures the directory exists.
    """
    # Create base directory if it doesn't exist
    os.makedirs(base_dir, exist_ok=True)
    filename = f"quickdraw_image_embeddings_categories{num_categories}_samples{num_images_per_category}.pkl"
    return os.path.join(base_dir, filename)

# Example usage:
# embeddings_file = get_embedding_filename(len(QUICKDRAW_CATEGORIES), NUM_TRAIN_SAMPLES_PER_CATEGORY)
# logger.info(f"Generated embedding filename: {embeddings_file}")

## BackUp


In [13]:
from google.cloud import aiplatform
# Corrected import: Import Image from vertexai.vision_models
from vertexai.vision_models import Image as VertexAiImage, MultiModalEmbeddingModel
# Import Image from PIL and BytesIO for converting PIL to bytes
from PIL import Image
from io import BytesIO

class MultimodalEmbeddingWrapper(ModelWrapper):
    def __init__(self, project_id, location, num_classes=len(QUICKDRAW_CATEGORIES), categories=QUICKDRAW_CATEGORIES, num_categories=len(QUICKDRAW_CATEGORIES), num_images_per_category=NUM_TRAIN_SAMPLES_PER_CATEGORY):
        super().__init__(None, num_classes)
        self.project_id = project_id
        self.location = location
        self.categories = categories
        # Add the new attributes
        self.num_categories = num_categories
        self.num_images_per_category = num_images_per_category


        # Initialize Vertex AI and load the model
        try:
            aiplatform.init(project=self.project_id, location=self.location)
            self.model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding@001")
            logger.info("Successfully initialized Vertex AI and loaded MultimodalEmbeddingModel.")
        except Exception as e:
            logger.error(f"Failed to initialize Vertex AI or load MultimodalEmbeddingModel: {e}")
            self.model = None # Ensure model is None if loading fails

        # Pre-compute text embeddings for categories
        self.category_text_embeddings = self._get_category_text_embeddings()

        # Add a linear layer to classify image embeddings based on text embeddings
        # The MultimodalEmbeddingModel outputs a fixed dimension (1408)
        if self.category_text_embeddings is not None and self.model is not None:
             # Compute the average text embedding to get the dimension
             avg_embedding = np.mean(self.category_text_embeddings, axis=0)
             input_dim = len(avg_embedding) # Should be 1408
             # Use cosine similarity for classification
             # The text embeddings act as the "weights" for classification
             # We don't need a separate linear layer for this approach
             self.classifier = None # No traditional classifier needed for this method
             logger.info(f"Initialized MultimodalEmbeddingWrapper for similarity-based classification.")
        else:
             self.classifier = None
             logger.warning("Could not initialize MultimodalEmbeddingWrapper for similarity-based classification.")


    def _get_category_text_embeddings(self):
        """Generates text embeddings for each category name."""
        if not self.model:
            logger.warning("MultimodalEmbeddingModel not loaded, cannot get text embeddings.")
            return None

        text_embeddings = []
        logger.info("Generating text embeddings for categories...")
        try:
            for category in tqdm(self.categories, desc="Generating text embeddings"):
                 # The context_text is optional but can improve quality for ambiguous terms
                 embeddings = self.model.get_embeddings(
                     contextual_text=category,
                     dimension=1408 # A required, fixed dimension size for this model
                 )
                 if embeddings and embeddings.text_embedding:
                     text_embeddings.append(embeddings.text_embedding)
                 else:
                     logger.warning(f"Could not get text embedding for category: {category}")

            if not text_embeddings:
                 logger.warning("No text embeddings were generated.")
                 return None

            return np.array(text_embeddings) # Convert to numpy array
        except Exception as e:
             logger.error(f"Error generating text embeddings: {e}")
             return None

    def calculate_cosine_similarity(self, vec1, vec2):
        """Calculates the cosine similarity between two vectors."""
        # Ensure inputs are numpy arrays for dot product and norm
        vec1 = np.asarray(vec1)
        vec2 = np.asarray(vec2)
        # Add a small epsilon to the denominator to avoid division by zero
        return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2) + 1e-8)

    def forward(self, images):
        """
        Generates image embeddings and classifies them using cosine similarity
        with pre-computed text embeddings.
        """
        if self.model is None or self.category_text_embeddings is None:
            logger.error("MultimodalEmbeddingModel or text embeddings not initialized.")
            # Return dummy output
            return torch.zeros(images.size(0), self.num_classes, device=images.device)

        # The MultimodalEmbeddingModel expects VertexAiImage objects initialized with bytes
        # Convert torch tensor batch to list of PIL Images, then to bytes, then to VertexAiImage
        vertex_ai_images = []
        for img_tensor in images:
             # Permute dimensions from C, H, W to H, W, C for numpy conversion
             # Convert to numpy array and scale to 0-255
             img_np = (img_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
             # Convert numpy array to PIL Image
             pil_image = Image.fromarray(img_np)
             # Convert PIL image to bytes
             buffer = BytesIO()
             # Assuming the input images are grayscale sketches, save as PNG
             pil_image.save(buffer, format="PNG")
             image_bytes = buffer.getvalue()
             # Create VertexAiImage object from bytes
             vertex_ai_images.append(VertexAiImage(image_bytes=image_bytes))


        # Get image embeddings from the MultimodalEmbeddingModel
        try:
            # Process images one by one as a workaround for the batch issue
            all_image_embeddings = []
            # Use tqdm for progress tracking
            for i in tqdm(range(len(vertex_ai_images)), desc="Generating image embeddings and classifying", leave=False):
                img_ai = vertex_ai_images[i]
                embeddings_response = self.model.get_embeddings(
                    image=img_ai, # Pass single VertexAiImage object
                    dimension=1408
                )

                if embeddings_response and embeddings_response.image_embedding:
                    all_image_embeddings.append(embeddings_response.image_embedding)
                else:
                    logger.warning(f"Could not get image embedding for image at index {i}. Skipping.")
                    # Append a zero vector or handle appropriately if an embedding is missing
                    # For now, we'll skip and this might cause a shape mismatch later if not all embeddings are generated.
                    # A better approach might be to append a zero vector of the expected dimension.
                    # Let's append a zero vector for now to maintain batch size consistency in the output.
                    all_image_embeddings.append(np.zeros(1408)) # Append zero vector of dimension 1408


            if not all_image_embeddings:
                 logger.warning("No image embeddings were generated.")
                 # Return dummy output if no embeddings were generated
                 return torch.zeros(images.size(0), self.num_classes, device=images.device)


            # Calculate similarity scores for each image embedding with all text embeddings
            batch_similarity_scores = []
            for img_embedding in all_image_embeddings:
                similarity_scores = [self.calculate_cosine_similarity(img_embedding, text_emb) for text_emb in self.category_text_embeddings]
                batch_similarity_scores.append(similarity_scores)

            # Convert similarity scores to logits (higher similarity = higher logit)
            logits = torch.tensor(np.array(batch_similarity_scores), dtype=torch.float32)

            # Move logits to the desired device
            logits = logits.to(images.device)

            return logits

        except Exception as e:
            logger.error(f"Error generating image embeddings or classifying: {e}")
            # Return dummy output
            return torch.zeros(images.size(0), self.num_classes, device=images.device)

    # MultimodalEmbeddingModel does not have state_dict in the traditional PyTorch sense
    # Override state_dict and load_state_dict to handle the classifier (which is now None)
    # and text embeddings if we were to save/load them.
    def state_dict(self):
        # We can save the text embeddings if needed, but for this approach, they are
        # re-generated during initialization based on category names.
        # Saving the state_dict isn't really applicable for this type of model.
        # Returning an empty dict or raising an error might be appropriate.
        # For compatibility with the existing framework, return a minimal dict.
        return {}

    def load_state_dict(self, state_dict):
        # No state to load for this model using this method.
        logger.warning("load_state_dict is not applicable for MultimodalEmbeddingWrapper in similarity mode.")


    @classmethod
    def load_from_checkpoint(cls, model_path, project_id, location, device=DEVICE):
        """Loading from checkpoint is not applicable for this wrapper."""
        logger.warning("load_from_checkpoint is not applicable for MultimodalEmbeddingWrapper in similarity mode.")
        # Return a new instance initialized from scratch
        return cls(project_id=project_id, location=location)
