In [1]:
!pip uninstall -y torch torchvision torchaudio sympy
!pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu12 # Replace cu118 with your CUDA version or remove if using CPU
!pip install sympy

Found existing installation: torch 2.6.0+cu124
Uninstalling torch-2.6.0+cu124:
  Successfully uninstalled torch-2.6.0+cu124
Found existing installation: torchvision 0.21.0+cu124
Uninstalling torchvision-0.21.0+cu124:
  Successfully uninstalled torchvision-0.21.0+cu124
Found existing installation: torchaudio 2.6.0+cu124
Uninstalling torchaudio-2.6.0+cu124:
  Successfully uninstalled torchaudio-2.6.0+cu124
Found existing installation: sympy 1.13.1
Uninstalling sympy-1.13.1:
  Successfully uninstalled sympy-1.13.1
Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu12
Collecting torch
  Downloading torch-2.7.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (29 kB)
Collecting torchvision
  Downloading torchvision-0.22.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (6.1 kB)
Collecting torchaudio
  Downloading torchaudio-2.7.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (6.6 kB)
Collecting sympy>=1.13.3 (from torch)
  Downloading sympy-1.14.0-py3-none-any.whl

In [16]:
from tqdm.notebook import tqdm # Import tqdm
import requests # Added for downloading
import urllib.parse # Added for URL encoding category names
import torch
import torchvision.models as models
import torchvision.transforms as T
# torchvision.datasets.QuickDraw is not used anymore
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
import time
import struct # For unpacking binary data
from struct import unpack
import os
import collections
import logging
import datetime
import csv
# import urllib.request # Removed: No longer downloading
from PIL import Image, ImageDraw


In [17]:
# --- Configuration ---
# QuickDraw categories to benchmark. get_quickdraw_data() uses each category's
# position in this list as its integer class label.
QUICKDRAW_CATEGORIES = [
    'apple', 'cat', 'dog', 'door', 'elephant', 'fish', 'flower', 'grapes',
    'grass', 'house', 'ice cream', 'jail', 'key', 'lion', 'moon', 'nose',
    'pencil', 'rabbit', 'sun', 'tree', 'umbrella', 'van', 'cake', 'airplane',
    'ant', 'banana', 'bed', 'bee', 'bicycle', 'bird', 'book', 'bread', 'bus',
    'elbow', 'ear', 'camera', 'car', 'chair', 'clock', 'cloud', 'hand',
    'computer', 'cookie', 'cow', 'crayon', 'cup', 'eraser', 'carrot', 'drums',
    'eye', 'knife'
]
NUM_TRAIN_SAMPLES_PER_CATEGORY = 5000  # training drawings sampled per category
NUM_TEST_SAMPLES_PER_CATEGORY = 1000   # held-out drawings per category (disjoint from train)
BATCH_SIZE = 64
QUICKDRAW_CACHE_SIZE = 20000  # max decoded drawings kept in each category dataset's LRU cache
# NOTE(review): QuickDrawBinaryDataset renders with its own IMAGE_SIZE / LINE_WIDTH
# class attributes; these module-level values are not read by that class.
IMAGE_SIZE = (256, 256)
LINE_WIDTH = 2
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BINARY_DATA_ROOT = '.'  # directory holding (or receiving) the full_binary_*.bin files
NUM_FINETUNE_EPOCHS = 3  # number of epochs for fine-tuning
FINETUNE_LEARNING_RATE = 1e-4  # learning rate for fine-tuning

# Seed the global RNGs so code drawing from them is repeatable under
# Restart Kernel -> Run All.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Create the data directory if it does not exist yet.
if not os.path.exists(BINARY_DATA_ROOT):
    os.makedirs(BINARY_DATA_ROOT, exist_ok=True)
    print(f"Warning: Data directory '{BINARY_DATA_ROOT}' was not found and has been created.")
    print(f"Please ensure QuickDraw .bin files (e.g., full_binary_apple.bin) for categories {QUICKDRAW_CATEGORIES} are placed there.")

# --- Logging Configuration ---
# Set to logging.DEBUG for verbose development output, or logging.INFO for less.
LOG_LEVEL = logging.INFO
# force=True replaces any handlers already on the root logger (from an earlier
# run of this cell or from a library); without it basicConfig silently does
# nothing once handlers exist, so changing LOG_LEVEL here would have no effect.
logging.basicConfig(level=LOG_LEVEL, format='%(asctime)s - %(levelname)s - %(module)s - %(message)s', force=True)
logger = logging.getLogger(__name__)


In [18]:

# --- Part 1: Data Download and Preparation ---
# %%
def download_quickdraw_binary(category_name, download_dir, timeout=60):
    """
    Download the QuickDraw binary file for one category into ``download_dir``.

    The file is saved as ``full_binary_{name}.bin``, where spaces in the
    category name become underscores (e.g. "ice cream" ->
    ``full_binary_ice_cream.bin``). An existing file is left untouched.

    The body is streamed into a temporary ``.part`` file. That file is renamed
    into place only after every chunk has been written. An interrupted
    download (including Ctrl-C) therefore never leaves a truncated ``.bin``
    that a later run would mistake for a complete one.

    Args:
        category_name: QuickDraw category name as spelled in the public bucket.
        download_dir: Destination directory (expected to exist already).
        timeout: Seconds passed to ``requests.get`` so a stalled connection
            cannot hang the notebook indefinitely.

    Returns:
        The local file path if the file already existed or was downloaded,
        otherwise ``None``. Download errors are printed, not raised.
    """
    filename_category_part = category_name.replace(' ', '_')
    local_filepath = os.path.join(download_dir, f"full_binary_{filename_category_part}.bin")

    if os.path.exists(local_filepath):
        print(f"File for '{category_name}' already exists: {local_filepath}")
        return local_filepath

    # The bucket uses the original spelling, so percent-encode it (e.g. "ice cream" -> "ice%20cream").
    url_category_part = urllib.parse.quote(category_name)
    url = f"https://storage.googleapis.com/quickdraw_dataset/full/binary/{url_category_part}.bin"
    partial_filepath = local_filepath + ".part"

    print(f"Downloading '{category_name}' from {url} to {local_filepath}...")
    try:
        # Using the response as a context manager guarantees the connection is released.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()  # Raise an exception for HTTP errors
            total_size = int(response.headers.get('content-length', 0))
            with open(partial_filepath, 'wb') as f, tqdm(
                desc=category_name,
                total=total_size or None,  # unknown length -> indeterminate progress bar
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for chunk in response.iter_content(chunk_size=8192):
                    bar.update(f.write(chunk))
        # Publish the file only once it is complete.
        os.replace(partial_filepath, local_filepath)
        print(f"Successfully downloaded '{category_name}'.")
        return local_filepath
    except requests.exceptions.RequestException as e:
        print(f"Error downloading '{category_name}': {e}")
    except Exception as e:
        print(f"An unexpected error occurred while downloading '{category_name}': {e}")
    finally:
        # On success the .part file has already been renamed away. This also
        # cleans up after KeyboardInterrupt, which the handlers above let propagate.
        if os.path.exists(partial_filepath):
            os.remove(partial_filepath)
    return None


# Fetch every configured category; files already on disk are skipped by the helper.
print(f"Starting download process for {len(QUICKDRAW_CATEGORIES)} categories into '{BINARY_DATA_ROOT}'...")
for category in QUICKDRAW_CATEGORIES:
    download_quickdraw_binary(
        category_name=category,
        download_dir=BINARY_DATA_ROOT,
    )
print("Download process finished.")


Starting download process for 51 categories into '.'...
File for 'apple' already exists: ./full_binary_apple.bin
File for 'cat' already exists: ./full_binary_cat.bin
File for 'dog' already exists: ./full_binary_dog.bin
File for 'door' already exists: ./full_binary_door.bin
File for 'elephant' already exists: ./full_binary_elephant.bin
File for 'fish' already exists: ./full_binary_fish.bin
File for 'flower' already exists: ./full_binary_flower.bin
File for 'grapes' already exists: ./full_binary_grapes.bin
File for 'grass' already exists: ./full_binary_grass.bin
File for 'house' already exists: ./full_binary_house.bin
File for 'ice cream' already exists: ./full_binary_ice_cream.bin
File for 'jail' already exists: ./full_binary_jail.bin
File for 'key' already exists: ./full_binary_key.bin
File for 'lion' already exists: ./full_binary_lion.bin
File for 'moon' already exists: ./full_binary_moon.bin
File for 'nose' already exists: ./full_binary_nose.bin
File for 'pencil' already exists: ./fu

In [None]:
# --- QuickDraw Binary Data Reading Functions (from user - Unchanged) ---
# Fixed-size record header: uint64 key_id, char[2] country code, int8
# recognized flag, uint32 timestamp, uint16 stroke count. The '<' prefix pins
# little-endian byte order with standard sizes and no padding (17 bytes). The
# original native-order formats gave the same result only on little-endian hosts.
_DRAWING_HEADER = struct.Struct('<Q2sbIH')
# Per-stroke point count (uint16, little-endian).
_STROKE_POINT_COUNT = struct.Struct('<H')


def unpack_drawing(file_handle):
    """Decode one QuickDraw binary drawing record from ``file_handle``.

    Reads from the handle's current position and leaves it just past the
    record, so consecutive calls walk the file.

    Returns:
        dict with ``key_id``, ``country_code`` (2 raw bytes), ``recognized``,
        ``timestamp`` and ``image``. ``image`` is a list of ``(xs, ys)`` tuples
        of ints in 0..255, one per stroke; zero-point strokes are ``((), ())``.

    Raises:
        struct.error: If the header or any stroke is truncated (e.g. at end of
            file or in a corrupted record).
    """
    try:
        header = file_handle.read(_DRAWING_HEADER.size)
        # Struct.unpack raises struct.error when fewer than 17 bytes were read.
        key_id, country_code, recognized, timestamp, n_strokes = _DRAWING_HEADER.unpack(header)

        image_strokes = []
        for _ in range(n_strokes):
            n_points, = _STROKE_POINT_COUNT.unpack(file_handle.read(_STROKE_POINT_COUNT.size))
            if n_points == 0:
                image_strokes.append((tuple(), tuple()))
                continue

            # All x coordinates come first, then all y coordinates, one unsigned byte each.
            x_bytes = file_handle.read(n_points)
            y_bytes = file_handle.read(n_points)

            if len(x_bytes) < n_points or len(y_bytes) < n_points:
                logger.error(f"Insufficient data for stroke points. Expected {n_points}, got {len(x_bytes)} for x, {len(y_bytes)} for y. Skipping drawing.")
                raise struct.error("Insufficient data for stroke points, likely corrupted drawing record.")

            # Iterating bytes yields ints 0..255, same as unpack(f"{n}B", ...).
            image_strokes.append((tuple(x_bytes), tuple(y_bytes)))

        return {
            'key_id': key_id,
            'country_code': country_code,
            'recognized': recognized,
            'timestamp': timestamp,
            'image': image_strokes
        }
    except struct.error as e:
        logger.debug(f"Struct error during unpack_drawing: {e}. File pointer at {file_handle.tell() if hasattr(file_handle, 'tell') else 'N/A'}")
        raise
    except Exception as e:
        logger.error(f"Unexpected error in unpack_drawing: {e}")
        raise


def unpack_drawings(filename):
    """Generate decoded drawings from ``filename`` one record at a time.

    Each yielded value is the dict returned by ``unpack_drawing``. Iteration
    ends at end of file, or at the first record that raises ``struct.error``
    or ``EOFError``. Such a record is treated as a truncated tail and logged
    at DEBUG level. A byte-based tqdm bar shows progress through the file.
    """
    total_bytes = os.path.getsize(filename)
    bar_label = f"Unpacking {os.path.basename(filename)}"
    with open(filename, 'rb') as handle, tqdm(total=total_bytes, unit='B', unit_scale=True, desc=bar_label, leave=False) as progress:
        record_start = handle.tell()
        while record_start < total_bytes:
            try:
                yield unpack_drawing(handle)
                progress.update(handle.tell() - record_start)
            except struct.error:
                logger.debug(f"Struct.error in unpack_drawings, likely end of file or data. File: {filename}")
                break
            except EOFError:
                logger.debug(f"EOFError in unpack_drawings. File: {filename}")
                break
            record_start = handle.tell()

# --- Custom QuickDraw Dataset from Local Binary Files (MODIFIED) ---
class QuickDrawBinaryDataset(Dataset):
    """Map-style dataset over one QuickDraw category stored as a local binary file.

    Construction scans ``{root}/full_binary_{category}.bin`` (spaces in
    ``category`` replaced by underscores) once. It records the byte offset of
    every decodable drawing record. ``__getitem__`` then seeks to a single
    offset, decodes that record with ``unpack_drawing``, rasterises its
    strokes onto a white grayscale canvas and applies ``transform``.

    Decoded stroke data is kept in a per-instance LRU cache (an ``OrderedDict``
    keyed by index) so repeated reads of the same index skip the disk.

    Args:
        root: Directory containing the ``.bin`` files.
        category: Category name, e.g. ``"ice cream"``.
        transform: Optional callable applied to the rendered PIL image.
        cache_size: Maximum number of decoded drawings kept in the cache;
            ``0`` disables caching.
        max_items: Optional cap on how many records are indexed. Scanning
            stops after this many, so only a *prefix* of the file is ever
            sampled. ``None`` (default) indexes the whole file.

    Raises:
        FileNotFoundError: If the binary file does not exist.

    Items are ``(image, category)``, where ``category`` is the underscored name.
    """

    # Rendering canvas size and stroke width. NOTE: these are class-level
    # values and are independent of the module-level IMAGE_SIZE / LINE_WIDTH.
    IMAGE_SIZE = (256, 256)
    LINE_WIDTH = 2

    def __init__(self, root, category, transform=None, cache_size=QUICKDRAW_CACHE_SIZE, max_items=None):  # cache_size in number of items
        self.root = root
        self.category = category.replace(' ', '_')
        self.transform = transform
        self.filepath = os.path.join(self.root, f"full_binary_{self.category}.bin")

        self.cache_size = cache_size
        self.drawing_cache = collections.OrderedDict()  # idx -> decoded drawing; least recently used first

        if not os.path.exists(self.filepath):
            raise FileNotFoundError(
                f"Dataset binary file not found: {self.filepath}. Please ensure it exists."
            )

        logger.info(f"Indexing drawings from {self.filepath} for category {self.category}...")
        self.drawing_offsets = self._index_offsets(max_items)

        if not self.drawing_offsets:
            logger.warning(f"No drawings were indexed for category {self.category} from {self.filepath}.")
        else:
            logger.info(f"Successfully indexed {len(self.drawing_offsets)} drawings for {self.category}. Cache capacity: {self.cache_size} items.")

    def _index_offsets(self, max_items=None):
        """Return the byte offset of each decodable record, in file order.

        Scanning stops at end of file, at the first record that fails to
        decode (logged, not raised), or once ``max_items`` offsets are found.
        """
        offsets = []
        file_size = os.path.getsize(self.filepath)
        with open(self.filepath, 'rb') as index_handle, tqdm(
            total=file_size, unit='B', unit_scale=True,
            desc=f"Indexing {os.path.basename(self.filepath)}", leave=False,
        ) as pbar:
            while max_items is None or len(offsets) < max_items:
                current_pos = index_handle.tell()
                if current_pos >= file_size:
                    break
                try:
                    # Decode only to validate the record and advance past it.
                    unpack_drawing(index_handle)
                except struct.error:
                    logger.debug(f"Finished indexing or encountered struct.error at offset {current_pos} in {self.filepath}. Total indexed: {len(offsets)}")
                    break
                except EOFError:
                    logger.debug(f"EOFError encountered while indexing {self.filepath} at offset {current_pos}. Total indexed: {len(offsets)}")
                    break
                except Exception as e:
                    logger.error(f"Unexpected error during indexing of {self.filepath} at offset {current_pos}: {e}. Stopping indexing for this file.")
                    break
                offsets.append(current_pos)
                pbar.update(index_handle.tell() - current_pos)
        return offsets

    def _render_drawing_to_image(self, drawing_strokes):
        """Draw ``(xs, ys)`` strokes in black on a white 8-bit grayscale canvas."""
        image = Image.new("L", self.IMAGE_SIZE, "white")
        draw = ImageDraw.Draw(image)
        for stroke_x, stroke_y in drawing_strokes:
            if not stroke_x or not stroke_y:
                continue  # zero-point stroke: nothing to draw
            if len(stroke_x) == 1:
                # Single-point strokes are drawn as one pixel rather than a line.
                draw.point((int(stroke_x[0]), int(stroke_y[0])), fill="black")
            else:
                draw.line(list(zip(stroke_x, stroke_y)), fill="black", width=self.LINE_WIDTH)
        return image

    def _read_drawing(self, idx):
        """Decode the record at ``drawing_offsets[idx]`` directly from disk.

        The file is opened per call instead of being held open on the
        instance, so no file handle is shared across DataLoader workers.

        Raises:
            IOError: Chained to the underlying error if decoding fails.
        """
        offset = self.drawing_offsets[idx]
        try:
            with open(self.filepath, 'rb') as f:
                f.seek(offset)
                return unpack_drawing(f)
        except Exception as e:
            logger.error(f"Error reading drawing at index {idx}, offset {offset} from {self.filepath}: {e}")
            raise IOError(f"Failed to load drawing {idx} for {self.category}") from e

    def __len__(self):
        return len(self.drawing_offsets)

    def __getitem__(self, idx):
        if idx < 0 or idx >= len(self.drawing_offsets):
            raise IndexError(f"Index {idx} out of bounds for {len(self.drawing_offsets)} drawings.")

        if idx in self.drawing_cache:
            self.drawing_cache.move_to_end(idx)  # mark as most recently used
            drawing_data = self.drawing_cache[idx]
        else:
            drawing_data = self._read_drawing(idx)
            if self.cache_size > 0:
                self.drawing_cache[idx] = drawing_data
                if len(self.drawing_cache) > self.cache_size:
                    self.drawing_cache.popitem(last=False)  # evict the least recently used entry

        pil_image = self._render_drawing_to_image(drawing_data['image'])

        if self.transform:
            pil_image = self.transform(pil_image)

        return pil_image, self.category


# --- Model Definitions and Feature Extractors (Unchanged) ---
# Backbones to benchmark, keyed by display name. Each entry supplies:
#   "weights"              - pretrained torchvision weights enum (its .transforms() gives preprocessing)
#   "model_fn"             - torchvision constructor called as model_fn(weights=...)
#   "feature_extractor_fn" - wraps a constructed model in a headless feature extractor
# The factories are lambdas so the extractor classes, defined further down,
# are looked up when called rather than when this dict is built.
# Disabled candidates that follow the same shape:
#   "ShuffleNetV2_x0_5": models.ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1,
#                        models.shufflenet_v2_x0_5, ShuffleNetV2FeatureExtractor
#   "EfficientNet-B0":   models.EfficientNet_B0_Weights.IMAGENET1K_V1,
#                        models.efficientnet_b0, EfficientNetFeatureExtractor
MODELS_TO_TEST = dict([
    ("MobileNetV3-Small", dict(
        weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1,
        model_fn=models.mobilenet_v3_small,
        feature_extractor_fn=lambda m: MobileNetV3FeatureExtractor(m),
    )),
    ("SqueezeNet1_1", dict(
        weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1,
        model_fn=models.squeezenet1_1,
        feature_extractor_fn=lambda m: SqueezeNetFeatureExtractor(m),
    )),
])

class MobileNetV3FeatureExtractor(torch.nn.Module):
    """Headless MobileNetV3: convolutional trunk plus global pooling, flattened.

    Shares (does not copy) the ``features`` and ``avgpool`` submodules of the
    wrapped model and omits its classifier. ``forward`` therefore returns one
    pooled feature vector per input.
    """

    def __init__(self, original_model):
        super().__init__()
        self.features = original_model.features
        self.avgpool = original_model.avgpool
        self.flatten = torch.nn.Flatten(1)

    def forward(self, x):
        pooled = self.avgpool(self.features(x))
        return self.flatten(pooled)

class ShuffleNetV2FeatureExtractor(torch.nn.Module):
    """Headless ShuffleNetV2: stem, the three stages and conv5, then global average pooling.

    The trunk submodules are shared with the wrapped model. Its final fully
    connected layer is not used. ``forward`` returns one flattened feature
    vector per input.
    """

    def __init__(self, original_model):
        super().__init__()
        # Register the trunk submodules in forward order under their original names.
        for name in ("conv1", "maxpool", "stage2", "stage3", "stage4", "conv5"):
            setattr(self, name, getattr(original_model, name))
        self.glob_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = torch.nn.Flatten(1)

    def forward(self, x):
        trunk = (self.conv1, self.maxpool, self.stage2, self.stage3, self.stage4, self.conv5)
        for layer in trunk:
            x = layer(x)
        return self.flatten(self.glob_pool(x))

class SqueezeNetFeatureExtractor(torch.nn.Module):
    """Headless SqueezeNet: the wrapped model's ``features`` trunk followed by
    a new adaptive average pool to 1x1 and a flatten.

    The original classifier (dropout / final conv) is not used.
    """

    def __init__(self, original_model):
        super().__init__()
        self.features = original_model.features
        self.pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = torch.nn.Flatten(1)

    def forward(self, x):
        return self.flatten(self.pool(self.features(x)))

class EfficientNetFeatureExtractor(torch.nn.Module):
    """Headless EfficientNet: shared ``features`` and ``avgpool`` from the
    wrapped model, then a flatten. The classifier head is dropped.
    """

    def __init__(self, original_model):
        super().__init__()
        self.features = original_model.features
        self.avgpool = original_model.avgpool
        self.flatten = torch.nn.Flatten(1)

    def forward(self, x):
        out = x
        for stage in (self.features, self.avgpool, self.flatten):
            out = stage(out)
        return out

# --- Data Loading and Preprocessing (MODIFIED) ---
class LabeledSubset(Subset):
    """A ``Subset`` whose items carry a fixed integer class label.

    The wrapped dataset must return ``(data, label)`` pairs. The inner label
    is discarded and ``label`` is returned in its place.
    """

    def __init__(self, dataset, indices, label):
        super().__init__(dataset, indices)
        self.label = label

    def __getitem__(self, idx):
        data, _ = super().__getitem__(idx)
        return data, self.label

    def __getitems__(self, indices):
        # Recent torch versions give Subset a batched __getitems__, which the
        # DataLoader prefers when it is available. Route it through
        # __getitem__ so the label override can never be bypassed.
        return [self[i] for i in indices]


def get_quickdraw_data(categories, num_train_per_cat, num_test_per_cat, base_model_transform, data_root, seed=42):
    """Build train/test datasets from disjoint random samples of each category.

    For each category, drawings are loaded from ``data_root`` and shuffled.
    The first ``num_train_per_cat`` go to train and the next
    ``num_test_per_cat`` go to test; both are capped by what is available.
    Each item's label is the category's index in ``categories``, so labels
    stay stable even if some categories are skipped.

    Args:
        categories: Category names to load.
        num_train_per_cat: Training drawings per category.
        num_test_per_cat: Test drawings per category (disjoint from train).
        base_model_transform: The backbone's preprocessing, applied after
            converting the 1-channel sketch to 3 channels.
        data_root: Directory holding the ``full_binary_*.bin`` files.
        seed: Seed for the per-category shuffles. A fixed seed gives the
            same split on every call, so models benchmarked in separate calls
            are evaluated on identical drawings. ``None`` gives a fresh random
            split each call.

    Returns:
        ``(train_dataset, test_dataset)`` as ``ConcatDataset`` objects.

    Raises:
        RuntimeError: If no usable train or no usable test data was loaded.
    """
    quickdraw_specific_transform = T.Compose([
        T.Grayscale(num_output_channels=3),  # rendered sketches are single-channel
        base_model_transform
    ])

    all_train_datasets = []
    all_test_datasets = []
    skipped_categories = []

    for label, category_name in enumerate(tqdm(categories, desc="Processing categories", unit="category")):
        print(f"Loading QuickDraw category: {category_name} from local binary files...")
        try:
            full_category_dataset = QuickDrawBinaryDataset(
                root=data_root,
                category=category_name,
                transform=quickdraw_specific_transform
            )
        except (FileNotFoundError, RuntimeError) as e:  # Catch load errors
            print(f"Could not load category {category_name}. Error: {e}. Skipping.")
            skipped_categories.append(category_name)
            continue

        n_available = len(full_category_dataset)
        if n_available == 0:
            print(f"Warning: No data loaded for category {category_name}. Skipping.")
            skipped_categories.append(category_name)
            continue

        # Train gets first claim on the drawings and test takes from the remainder,
        # so the splits never overlap and never exceed what is available.
        actual_num_train = min(num_train_per_cat, n_available)
        actual_num_test = min(num_test_per_cat, n_available - actual_num_train)

        if actual_num_train == 0 and num_train_per_cat > 0:
            print(f"Warning: Could not get any training samples for {category_name}")
        if actual_num_test == 0 and num_test_per_cat > 0:
            print(f"Warning: Could not get any testing samples for {category_name} after allocating for training.")

        # One generator per category, so a category's split does not depend on
        # which earlier categories loaded successfully.
        rng = np.random.default_rng(None if seed is None else [seed, label])
        indices = rng.permutation(n_available)

        train_indices = indices[:actual_num_train]
        test_indices = indices[actual_num_train:actual_num_train + actual_num_test]

        if len(train_indices) > 0:
            all_train_datasets.append(LabeledSubset(full_category_dataset, train_indices, label))
        if len(test_indices) > 0:
            all_test_datasets.append(LabeledSubset(full_category_dataset, test_indices, label))

    if skipped_categories:
        print(f"Warning: Some categories failed to load, proceeding with available data. Skipped: {skipped_categories}")

    if not all_train_datasets and not all_test_datasets:
        raise RuntimeError("No QuickDraw data could be loaded for any category. Aborting.")
    if not all_train_datasets or not all_test_datasets:
        raise RuntimeError("No usable QuickDraw train/test data after processing categories. Aborting.")

    return ConcatDataset(all_train_datasets), ConcatDataset(all_test_datasets)


# --- Feature Extraction (MODIFIED) ---
def extract_features(model, dataloader, device, description="Extracting features"):
    """Run ``model`` over every batch of ``dataloader`` and gather its outputs.

    The model is put in eval mode and moved to ``device``. Forward passes run
    without gradient tracking. Tensor labels are moved to the CPU as NumPy;
    any other label batches are kept as given.

    Returns:
        ``(features, labels)`` concatenated along the batch axis, or two empty
        arrays if the loader yields no batches.
    """
    model.eval()
    model.to(device)
    feature_batches, label_batches = [], []
    progress = tqdm(dataloader, desc=description, leave=False, unit="batch")
    for batch_inputs, batch_labels in progress:
        with torch.no_grad():
            batch_outputs = model(batch_inputs.to(device))
        feature_batches.append(batch_outputs.cpu().detach().numpy())
        label_batches.append(
            batch_labels.cpu().detach().numpy()
            if isinstance(batch_labels, torch.Tensor)
            else batch_labels
        )

    if len(feature_batches) == 0:
        return np.array([]), np.array([])
    return np.concatenate(feature_batches), np.concatenate(label_batches)

# --- Main Benchmarking Loop (MODIFIED calls to extract_features) ---
def run_benchmark():
    """Benchmark every backbone in MODELS_TO_TEST as a frozen feature extractor.

    For each model:
    - load its pretrained weights;
    - build QuickDraw train/test sets with that model's preprocessing;
    - extract pooled features;
    - fit a StandardScaler + LogisticRegression probe;
    - record test accuracy, accuracy per million parameters, and timings.

    If one model fails, the error is logged, its row keeps "Error"/"N/A"
    placeholders, and the loop moves on to the next model.

    Afterwards the rows are printed as a table and written to
    ``feature_extraction_benchmark_results_<YYYYmmdd_HHMMSS>.csv`` in the
    current working directory. Returns None.
    """
    results = []
    logger.info(f"Using device: {DEVICE}")
    logger.info(f"Benchmarking on QuickDraw categories: {', '.join(QUICKDRAW_CATEGORIES)}")
    logger.info(f"Samples per category: {NUM_TRAIN_SAMPLES_PER_CATEGORY} train, {NUM_TEST_SAMPLES_PER_CATEGORY} test.")
    logger.info(f"Loading QuickDraw data from local directory: {os.path.abspath(BINARY_DATA_ROOT)}\n")

    for model_name, config in tqdm(MODELS_TO_TEST.items(), desc="Benchmarking Models (Feature Extraction)", unit="model"):
        logger.info(f"--- Testing Model (Feature Extraction): {model_name} ---")

        # Placeholders reported if any step below fails.
        current_model_accuracy = "Error"
        current_model_acc_per_param = "Error"
        current_model_feat_ext_time = "N/A"
        current_model_train_time = "N/A"
        current_model_lr_inference_time = "N/A"
        current_model_lr_inference_speed = "N/A"
        model_params = "Error"  # Initialize in case model loading fails

        try:
            weights = config["weights"]
            base_model = config["model_fn"](weights=weights)
            # Trainable parameters, in millions.
            model_params = sum(p.numel() for p in base_model.parameters() if p.requires_grad) / 1_000_000

            model_specific_transforms = weights.transforms()

            train_dataset, test_dataset = get_quickdraw_data(
                QUICKDRAW_CATEGORIES,
                NUM_TRAIN_SAMPLES_PER_CATEGORY,
                NUM_TEST_SAMPLES_PER_CATEGORY,
                model_specific_transforms,
                BINARY_DATA_ROOT
            )

            if len(train_dataset) == 0 or len(test_dataset) == 0:
                logger.warning(f"Not enough data loaded for {model_name} to proceed. Skipping.")
                raise RuntimeError("Empty dataset after loading.")

            # Order is irrelevant for feature extraction, so no shuffling.
            train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)
            test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

            feature_extractor = config["feature_extractor_fn"](base_model)
            feature_extractor.to(DEVICE)
            feature_extractor.eval()

            logger.info("Extracting training features...")
            start_time = time.time()
            train_features, train_labels = extract_features(feature_extractor, train_loader, DEVICE, description=f"Extracting train features ({model_name})")
            feat_ext_time_train = time.time() - start_time

            if train_features.size == 0:
                logger.warning(f"No training features extracted for {model_name}. Skipping further steps for this model.")
                raise RuntimeError("No training features.")
            logger.debug(f"Train features shape: {train_features.shape}, Train labels shape: {train_labels.shape}")

            logger.info("Extracting test features...")
            start_time = time.time()
            test_features, test_labels = extract_features(feature_extractor, test_loader, DEVICE, description=f"Extracting test features ({model_name})")
            feat_ext_time_test = time.time() - start_time

            if test_features.size == 0:
                logger.warning(f"No test features extracted for {model_name}. Skipping further steps for this model.")
                raise RuntimeError("No test features.")
            logger.debug(f"Test features shape: {test_features.shape}, Test labels shape: {test_labels.shape}")

            current_model_feat_ext_time = f"{feat_ext_time_train + feat_ext_time_test:.2f}"

            logger.info("Training Logistic Regression classifier...")
            start_time = time.time()
            classifier = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000, random_state=42, solver='liblinear', C=0.1))
            classifier.fit(train_features, train_labels)
            current_model_train_time = f"{time.time() - start_time:.2f}"

            accuracy = classifier.score(test_features, test_labels) * 100
            logger.info(f"QuickDraw Test Accuracy for {model_name} (LogReg): {accuracy:.2f}%")
            current_model_accuracy = f"{accuracy:.2f}"
            current_model_acc_per_param = f"{accuracy / model_params if isinstance(model_params, (int, float)) and model_params > 0 else 0:.2f}"

            # Time the probe's prediction over the whole test set.
            start_time_lr_inference = time.time()
            _ = classifier.predict(test_features)
            lr_inference_duration = time.time() - start_time_lr_inference
            current_model_lr_inference_time = f"{lr_inference_duration:.4f}"  # Higher precision for inference time
            num_test_samples_lr = len(test_labels)
            current_model_lr_inference_speed = f"{num_test_samples_lr / lr_inference_duration if lr_inference_duration > 0 else 0:.2f}"
            logger.info(f"Logistic Regression Inference Time: {current_model_lr_inference_time}s, Speed: {current_model_lr_inference_speed} samples/s")

        except RuntimeError as e:
            logger.error(f"Runtime error during feature extraction processing for {model_name}: {e}")
        except Exception as e:
            logger.exception(f"An unexpected error occurred for model {model_name} (Feature Extraction): {e}")  # logger.exception includes the traceback
        finally:
            results.append({
                "Model": model_name,
                "Accuracy (%)": current_model_accuracy,
                "Params (M)": f"{model_params:.2f}" if isinstance(model_params, (int, float)) else model_params,
                "Acc/Params": current_model_acc_per_param,
                "Feat Ext Time (s)": current_model_feat_ext_time,
                "LR Train Time (s)": current_model_train_time,
                "LR Inference Time (s)": current_model_lr_inference_time,
                "LR Inference Speed (samples/s)": current_model_lr_inference_speed
            })
            logger.info("-" * 30 + "\n")
            # Drop this model's large objects before the next model is loaded,
            # so their memory (including GPU memory) can be reclaimed.
            if 'base_model' in locals(): del base_model
            if 'feature_extractor' in locals(): del feature_extractor
            if 'train_dataset' in locals(): del train_dataset
            if 'test_dataset' in locals(): del test_dataset
            if 'train_loader' in locals(): del train_loader
            if 'test_loader' in locals(): del test_loader
            if 'train_features' in locals(): del train_features
            if 'train_labels' in locals(): del train_labels
            if 'test_features' in locals(): del test_features
            if 'test_labels' in locals(): del test_labels
            # .type also matches indexed devices such as cuda:0.
            if DEVICE.type == "cuda":
                torch.cuda.empty_cache()

    print("\n--- Feature Extraction Benchmark Results ---")
    if not results:
        print("No results to display for feature extraction benchmark.")
        return

    # Left-aligned table: each column is as wide as its widest cell, plus 2.
    headers = list(results[0].keys())
    col_widths = {
        key: max([len(key)] + [len(str(row.get(key, ""))) for row in results]) + 2
        for key in headers
    }

    header_line = " | ".join(f"{h:<{col_widths[h]}}" for h in headers)
    print(header_line)
    print("-" * len(header_line))
    for row in results:
        print(" | ".join(f"{str(row.get(h, 'N/A')):<{col_widths[h]}}" for h in headers))

    # `datetime` is imported as a module, so the class is datetime.datetime.
    # A bare datetime.now() raises AttributeError and the CSV is never written.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    results_filename = f"feature_extraction_benchmark_results_{timestamp}.csv"
    try:
        with open(results_filename, 'w', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=headers)
            writer.writeheader()
            writer.writerows(results)
        logger.info(f"Feature extraction benchmark results saved to {results_filename}")
    except OSError as e:
        logger.error(f"Could not save feature extraction results to CSV: {e}")

# --- Fine-tuning Benchmarking Loop (NEW) ---
def run_finetuning_benchmark():
    """Fine-tune each model in ``MODELS_TO_TEST`` on QuickDraw and report the results.

    For every model: load it with its configured pretrained weights, replace the
    classifier head with a ``len(QUICKDRAW_CATEGORIES)``-way head, train it with
    AdamW for ``NUM_FINETUNE_EPOCHS`` epochs, then measure test accuracy and
    inference throughput. A per-model failure is logged and recorded as an
    "Error" row instead of aborting the whole benchmark; models whose head cannot
    be adapted are recorded as "Skipped". All rows are printed as a table and
    written to a timestamped CSV file in the current working directory.
    """
    finetuning_results = []
    logger.info("\n--- Starting Fine-tuning Benchmark ---")
    logger.info(f"Using device: {DEVICE}")
    logger.info(f"Fine-tuning on QuickDraw categories: {', '.join(QUICKDRAW_CATEGORIES)}")
    logger.info(f"Samples per category: {NUM_TRAIN_SAMPLES_PER_CATEGORY} train, {NUM_TEST_SAMPLES_PER_CATEGORY} test.")
    logger.info(f"Number of fine-tuning epochs: {NUM_FINETUNE_EPOCHS}, Learning Rate: {FINETUNE_LEARNING_RATE}")
    logger.info(f"Loading QuickDraw data from local directory: {os.path.abspath(BINARY_DATA_ROOT)}\n")

    num_classes = len(QUICKDRAW_CATEGORIES)
    # Works for torch.device objects and plain strings, including indexed devices:
    # torch.device("cuda:0") != torch.device("cuda"), so an equality check misses them.
    use_cuda = str(DEVICE).startswith("cuda")

    for model_name, config in tqdm(MODELS_TO_TEST.items(), desc="Benchmarking Models (Fine-tuning)", unit="model"):
        logger.info(f"--- Fine-tuning Model: {model_name} ---")

        current_ft_accuracy = "Error"
        current_ft_train_time = "N/A"
        current_ft_inference_time = "N/A"
        current_ft_inference_speed = "N/A"
        model_params_ft = "Error"

        # Bind every name that can hold (GPU) memory up front so the finally block
        # can release all of them no matter where the try block stopped.
        model_to_finetune = optimizer = None
        train_dataset_ft = test_dataset_ft = None
        train_loader_ft = test_loader_ft = None
        inputs = labels = outputs = loss = None

        try:
            weights = config["weights"]
            model_to_finetune = config["model_fn"](weights=weights)

            if not _ft_replace_classifier_head(model_to_finetune, model_name, num_classes):
                logger.warning(f"Classifier modification not defined for {model_name}. Skipping fine-tuning.")
                current_ft_accuracy = "Skipped"
                model_params_ft = "N/A"
                continue  # the finally block still records a row for this model

            # Count parameters AFTER swapping the head so the figure describes the
            # model that is actually fine-tuned, not the original pretrained head.
            model_params_ft = sum(p.numel() for p in model_to_finetune.parameters() if p.requires_grad) / 1_000_000

            model_to_finetune.to(DEVICE)

            model_specific_transforms = weights.transforms()
            train_dataset_ft, test_dataset_ft = get_quickdraw_data(
                QUICKDRAW_CATEGORIES,
                NUM_TRAIN_SAMPLES_PER_CATEGORY,
                NUM_TEST_SAMPLES_PER_CATEGORY,
                model_specific_transforms,
                BINARY_DATA_ROOT
            )

            if len(train_dataset_ft) == 0 or len(test_dataset_ft) == 0:
                logger.warning(f"Not enough data loaded for {model_name} (fine-tuning) to proceed. Skipping.")
                raise RuntimeError("Empty dataset after loading for fine-tuning.")

            train_loader_ft = DataLoader(train_dataset_ft, batch_size=BATCH_SIZE, shuffle=True)  # shuffle for training
            test_loader_ft = DataLoader(test_dataset_ft, batch_size=BATCH_SIZE, shuffle=False)

            optimizer = torch.optim.AdamW(model_to_finetune.parameters(), lr=FINETUNE_LEARNING_RATE)
            criterion = torch.nn.CrossEntropyLoss()

            # --- Fine-tuning loop ---
            logger.info(f"Starting fine-tuning for {model_name} for {NUM_FINETUNE_EPOCHS} epochs...")
            # perf_counter is monotonic and high-resolution, which suits benchmark timing.
            start_finetune_time = time.perf_counter()
            for epoch in range(NUM_FINETUNE_EPOCHS):
                model_to_finetune.train()
                running_loss = 0.0
                for inputs, labels in tqdm(train_loader_ft, desc=f"Epoch {epoch+1}/{NUM_FINETUNE_EPOCHS} Training", leave=False, unit="batch"):
                    # CrossEntropyLoss expects int64 class indices.
                    inputs, labels = inputs.to(DEVICE), labels.to(DEVICE).long()

                    optimizer.zero_grad()
                    outputs = model_to_finetune(inputs)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    optimizer.step()
                    # Weight by batch size so the epoch mean stays exact with a short last batch.
                    running_loss += loss.item() * inputs.size(0)

                epoch_loss = running_loss / len(train_loader_ft.dataset)
                logger.info(f"Epoch {epoch+1}/{NUM_FINETUNE_EPOCHS}, Train Loss: {epoch_loss:.4f}")

            current_ft_train_time = f"{time.perf_counter() - start_finetune_time:.2f}"
            logger.info(f"Fine-tuning for {model_name} completed in {current_ft_train_time}s.")

            # --- Evaluation of the fine-tuned model ---
            logger.info(f"Evaluating fine-tuned {model_name}...")
            model_to_finetune.eval()
            correct = 0
            total = 0
            start_ft_inference_time = time.perf_counter()
            with torch.no_grad():
                for inputs, labels in tqdm(test_loader_ft, desc=f"Evaluating {model_name}", leave=False, unit="batch"):
                    inputs, labels = inputs.to(DEVICE), labels.to(DEVICE).long()
                    outputs = model_to_finetune(inputs)
                    predicted = outputs.argmax(dim=1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

            ft_inference_duration = time.perf_counter() - start_ft_inference_time
            current_ft_inference_time = f"{ft_inference_duration:.4f}"
            current_ft_accuracy = f"{100 * correct / total if total > 0 else 0:.2f}"
            current_ft_inference_speed = f"{total / ft_inference_duration if ft_inference_duration > 0 else 0:.2f}"

            logger.info(f"Fine-tuned Test Accuracy for {model_name}: {current_ft_accuracy}%")
            logger.info(f"Fine-tuned Inference Time: {current_ft_inference_time}s, Speed: {current_ft_inference_speed} samples/s")

        except RuntimeError as e:
            logger.error(f"Runtime error during fine-tuning for {model_name}: {e}")
        except Exception as e:
            logger.exception(f"An unexpected error occurred for model {model_name} (Fine-tuning): {e}")
        finally:
            finetuning_results.append({
                "Model": model_name,
                "FT Accuracy (%)": current_ft_accuracy,
                "Params (M)": f"{model_params_ft:.2f}" if isinstance(model_params_ft, (int, float)) else model_params_ft,
                "FT Train Time (s)": current_ft_train_time,
                "FT Inference Time (s)": current_ft_inference_time,
                "FT Inference Speed (samples/s)": current_ft_inference_speed
            })
            logger.info("-" * 30 + "\n")
            # Drop references to the model, optimizer state (AdamW keeps per-parameter
            # buffers), datasets, loaders and the last batch tensors. empty_cache() can
            # only hand back memory that nothing references any more.
            model_to_finetune = optimizer = None
            train_dataset_ft = test_dataset_ft = None
            train_loader_ft = test_loader_ft = None
            inputs = labels = outputs = loss = None
            if use_cuda:
                torch.cuda.empty_cache()

    _ft_print_and_save_results(finetuning_results)


def _ft_replace_classifier_head(model, model_name, num_classes):
    """Replace the pretrained classifier head of ``model`` in place with a ``num_classes``-way head.

    Args:
        model: torchvision model instance to modify.
        model_name: Key from ``MODELS_TO_TEST`` identifying the architecture.
        num_classes: Number of output classes for the new head.

    Returns:
        True if the head was replaced, False if ``model_name`` is not supported.
    """
    if model_name == "MobileNetV3-Small":
        in_features = model.classifier[3].in_features
        model.classifier[3] = torch.nn.Linear(in_features, num_classes)
        return True
    if model_name == "SqueezeNet1_1":
        # SqueezeNet's classifier is a 1x1 Conv2d rather than a Linear layer.
        in_channels = model.classifier[1].in_channels
        model.classifier[1] = torch.nn.Conv2d(in_channels, num_classes, kernel_size=1)
        return True
    return False


def _ft_print_and_save_results(finetuning_results):
    """Print fine-tuning result rows as an aligned table and save them to a timestamped CSV.

    Args:
        finetuning_results: List of dicts sharing the same keys; the first row's
            keys define the column order.
    """
    # Local import: the notebook's top-level ``import datetime`` binds the module,
    # so ``datetime.now()`` would raise AttributeError.
    from datetime import datetime as _datetime

    print("\n--- Fine-tuning Benchmark Results ---")
    if not finetuning_results:
        print("No results to display for fine-tuning benchmark.")
        return

    ft_headers = list(finetuning_results[0].keys())
    # Column width = widest of header and cell text, plus 2 characters of padding.
    ft_col_widths = {
        key: max([len(key)] + [len(str(row.get(key, ""))) for row in finetuning_results]) + 2
        for key in ft_headers
    }

    ft_header_line = " | ".join(f"{h:<{ft_col_widths[h]}}" for h in ft_headers)
    print(ft_header_line)
    print("-" * len(ft_header_line))
    for row in finetuning_results:
        print(" | ".join(f"{str(row.get(h, 'N/A')):<{ft_col_widths[h]}}" for h in ft_headers))

    timestamp = _datetime.now().strftime("%Y%m%d_%H%M%S")
    ft_results_filename = f"finetuning_benchmark_results_{timestamp}.csv"
    try:
        with open(ft_results_filename, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=ft_headers)
            writer.writeheader()
            writer.writerows(finetuning_results)
        logger.info(f"Fine-tuning benchmark results saved to {ft_results_filename}")
    except OSError as e:
        logger.error(f"Could not save fine-tuning results to CSV: {e}")



if __name__ == "__main__":
    # Run both benchmarks in sequence: the feature-extraction benchmark first,
    # then the fine-tuning benchmark.
    for benchmark_fn in (run_benchmark, run_finetuning_benchmark):
        benchmark_fn()

Benchmarking Models (Feature Extraction):   0%|          | 0/2 [00:00<?, ?model/s]

Processing categories:   0%|          | 0/51 [00:00<?, ?category/s]

Loading QuickDraw category: apple from local binary files...


Indexing full_binary_apple.bin:   0%|          | 0.00/13.8M [00:00<?, ?B/s]

Loading QuickDraw category: cat from local binary files...


Indexing full_binary_cat.bin:   0%|          | 0.00/19.6M [00:00<?, ?B/s]

Loading QuickDraw category: dog from local binary files...


Indexing full_binary_dog.bin:   0%|          | 0.00/23.5M [00:00<?, ?B/s]

Loading QuickDraw category: door from local binary files...


Indexing full_binary_door.bin:   0%|          | 0.00/8.54M [00:00<?, ?B/s]

Loading QuickDraw category: elephant from local binary files...


Indexing full_binary_elephant.bin:   0%|          | 0.00/18.8M [00:00<?, ?B/s]

Loading QuickDraw category: fish from local binary files...


Indexing full_binary_fish.bin:   0%|          | 0.00/12.1M [00:00<?, ?B/s]

Loading QuickDraw category: flower from local binary files...


Indexing full_binary_flower.bin:   0%|          | 0.00/21.5M [00:00<?, ?B/s]

Loading QuickDraw category: grapes from local binary files...


Indexing full_binary_grapes.bin:   0%|          | 0.00/32.9M [00:00<?, ?B/s]

Loading QuickDraw category: grass from local binary files...


Indexing full_binary_grass.bin:   0%|          | 0.00/11.9M [00:00<?, ?B/s]

Loading QuickDraw category: house from local binary files...


Indexing full_binary_house.bin:   0%|          | 0.00/10.9M [00:00<?, ?B/s]

Loading QuickDraw category: ice cream from local binary files...


Indexing full_binary_ice_cream.bin:   0%|          | 0.00/11.9M [00:00<?, ?B/s]

Loading QuickDraw category: jail from local binary files...


Indexing full_binary_jail.bin:   0%|          | 0.00/11.5M [00:00<?, ?B/s]

Loading QuickDraw category: key from local binary files...


Indexing full_binary_key.bin:   0%|          | 0.00/16.9M [00:00<?, ?B/s]

Loading QuickDraw category: lion from local binary files...


Indexing full_binary_lion.bin:   0%|          | 0.00/24.7M [00:00<?, ?B/s]

Loading QuickDraw category: moon from local binary files...


Indexing full_binary_moon.bin:   0%|          | 0.00/11.6M [00:00<?, ?B/s]

Loading QuickDraw category: nose from local binary files...


Indexing full_binary_nose.bin:   0%|          | 0.00/13.8M [00:00<?, ?B/s]

Loading QuickDraw category: pencil from local binary files...


Indexing full_binary_pencil.bin:   0%|          | 0.00/10.1M [00:00<?, ?B/s]

Loading QuickDraw category: rabbit from local binary files...


Indexing full_binary_rabbit.bin:   0%|          | 0.00/25.3M [00:00<?, ?B/s]

Loading QuickDraw category: sun from local binary files...


Indexing full_binary_sun.bin:   0%|          | 0.00/15.9M [00:00<?, ?B/s]

Loading QuickDraw category: tree from local binary files...


Indexing full_binary_tree.bin:   0%|          | 0.00/20.5M [00:00<?, ?B/s]

Loading QuickDraw category: umbrella from local binary files...


Indexing full_binary_umbrella.bin:   0%|          | 0.00/11.5M [00:00<?, ?B/s]

Loading QuickDraw category: van from local binary files...


Indexing full_binary_van.bin:   0%|          | 0.00/21.7M [00:00<?, ?B/s]

Loading QuickDraw category: cake from local binary files...


Indexing full_binary_cake.bin:   0%|          | 0.00/17.8M [00:00<?, ?B/s]

Loading QuickDraw category: airplane from local binary files...


Indexing full_binary_airplane.bin:   0%|          | 0.00/15.7M [00:00<?, ?B/s]

Loading QuickDraw category: ant from local binary files...


Indexing full_binary_ant.bin:   0%|          | 0.00/18.6M [00:00<?, ?B/s]

Loading QuickDraw category: banana from local binary files...


Indexing full_binary_banana.bin:   0%|          | 0.00/25.1M [00:00<?, ?B/s]

Loading QuickDraw category: bed from local binary files...


Indexing full_binary_bed.bin:   0%|          | 0.00/11.1M [00:00<?, ?B/s]

Loading QuickDraw category: bee from local binary files...


Indexing full_binary_bee.bin:   0%|          | 0.00/20.5M [00:00<?, ?B/s]

Loading QuickDraw category: bicycle from local binary files...


Indexing full_binary_bicycle.bin:   0%|          | 0.00/18.4M [00:00<?, ?B/s]

Loading QuickDraw category: bird from local binary files...


Indexing full_binary_bird.bin:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

Loading QuickDraw category: book from local binary files...


Indexing full_binary_book.bin:   0%|          | 0.00/14.1M [00:00<?, ?B/s]

Loading QuickDraw category: bread from local binary files...


Indexing full_binary_bread.bin:   0%|          | 0.00/9.74M [00:00<?, ?B/s]

Loading QuickDraw category: bus from local binary files...


Indexing full_binary_bus.bin:   0%|          | 0.00/25.3M [00:00<?, ?B/s]

Loading QuickDraw category: elbow from local binary files...


Indexing full_binary_elbow.bin:   0%|          | 0.00/9.56M [00:00<?, ?B/s]

Loading QuickDraw category: ear from local binary files...


Indexing full_binary_ear.bin:   0%|          | 0.00/10.2M [00:00<?, ?B/s]

Loading QuickDraw category: camera from local binary files...


Indexing full_binary_camera.bin:   0%|          | 0.00/13.6M [00:00<?, ?B/s]

Loading QuickDraw category: car from local binary files...


Indexing full_binary_car.bin:   0%|          | 0.00/24.4M [00:00<?, ?B/s]

Loading QuickDraw category: chair from local binary files...


Indexing full_binary_chair.bin:   0%|          | 0.00/17.6M [00:00<?, ?B/s]

Loading QuickDraw category: clock from local binary files...


Indexing full_binary_clock.bin:   0%|          | 0.00/12.9M [00:00<?, ?B/s]

Loading QuickDraw category: cloud from local binary files...


Indexing full_binary_cloud.bin:   0%|          | 0.00/13.3M [00:00<?, ?B/s]

Loading QuickDraw category: hand from local binary files...


Indexing full_binary_hand.bin:   0%|          | 0.00/31.2M [00:00<?, ?B/s]

Loading QuickDraw category: computer from local binary files...


Indexing full_binary_computer.bin:   0%|          | 0.00/14.3M [00:00<?, ?B/s]

Loading QuickDraw category: cookie from local binary files...


Indexing full_binary_cookie.bin:   0%|          | 0.00/20.6M [00:00<?, ?B/s]

Loading QuickDraw category: cow from local binary files...


Indexing full_binary_cow.bin:   0%|          | 0.00/25.6M [00:00<?, ?B/s]

Loading QuickDraw category: crayon from local binary files...


Indexing full_binary_crayon.bin:   0%|          | 0.00/11.3M [00:00<?, ?B/s]

Loading QuickDraw category: cup from local binary files...


Indexing full_binary_cup.bin:   0%|          | 0.00/12.7M [00:00<?, ?B/s]

Loading QuickDraw category: eraser from local binary files...


Indexing full_binary_eraser.bin:   0%|          | 0.00/10.9M [00:00<?, ?B/s]

Loading QuickDraw category: carrot from local binary files...


Indexing full_binary_carrot.bin:   0%|          | 0.00/14.1M [00:00<?, ?B/s]

Loading QuickDraw category: drums from local binary files...


Indexing full_binary_drums.bin:   0%|          | 0.00/19.7M [00:00<?, ?B/s]

Loading QuickDraw category: eye from local binary files...


Indexing full_binary_eye.bin:   0%|          | 0.00/16.9M [00:00<?, ?B/s]

Loading QuickDraw category: knife from local binary files...


Indexing full_binary_knife.bin:   0%|          | 0.00/12.6M [00:00<?, ?B/s]

Extracting train features (MobileNetV3-Small):   0%|          | 0/3985 [00:00<?, ?batch/s]