In [None]:
!pip uninstall -y -q torch torchvision sympy
!pip install -q torch torchvision --extra-index-url https://download.pytorch.org/whl/cu121 # Replace cu121 with the tag for your CUDA version, or drop --extra-index-url for a CPU-only install
!pip install -q sympy

In [None]:
# Standard library
import collections
import csv
import gzip
import json
import logging
import os
import pickle
import random
import struct # For unpacking binary data
import time
import urllib.parse # Added for URL encoding category names
# import urllib.request # Removed: No longer downloading
from datetime import datetime
from struct import unpack

# Third-party
import numpy as np
import requests # Added for downloading
import torch
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image, ImageDraw
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from torch import nn
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset
from tqdm.notebook import tqdm # Import tqdm
try:
    import git
    GIT_AVAILABLE = True
except ImportError:
    GIT_AVAILABLE = False
    print("gitpython not installed. Using default versioning.")


In [None]:
# Configuration
# QuickDraw category names used throughout the notebook (51 entries).
# Names may contain spaces (e.g. 'ice cream'); the download/dataset code
# replaces spaces with underscores when building local file names.
QUICKDRAW_CATEGORIES = [
    'apple', 'cat', 'dog', 'door', 'elephant', 'fish', 'flower', 'table',
    'grass', 'house', 'ice cream', 'circle', 'key', 'lion', 'moon', 'nose',
    'pencil', 'rabbit', 'sun', 'tree', 'umbrella', 'van', 'cake', 'airplane',
    'ant', 'banana', 'bed', 'bee', 'bicycle', 'bird', 'book', 'bread', 'bus',
    'elbow', 'ear', 'camera', 'car', 'chair', 'clock', 'cloud', 'hand',
    'computer', 'cookie', 'cow', 'crayon', 'cup', 'eraser', 'carrot', 'drums',
    'eye', 'knife'
]
# Per-category sample budgets. The commented-out values are larger
# alternative settings kept for reference.
# NUM_TRAIN_SAMPLES_PER_CATEGORY = 5000
NUM_TRAIN_SAMPLES_PER_CATEGORY = 100
# NUM_TEST_SAMPLES_PER_CATEGORY = 1000
NUM_TEST_SAMPLES_PER_CATEGORY = 20
QUICKDRAW_CACHE_SIZE=50000  # Default `cache_size` for QuickDrawBinaryDataset (max cached drawings per worker cache)
IMAGE_SIZE = (256, 256)  # NOTE(review): QuickDrawBinaryDataset renders with its own IMAGE_SIZE class attribute (same value)
LINE_WIDTH = 2  # NOTE(review): QuickDrawBinaryDataset draws with its own LINE_WIDTH class attribute (same value)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # CUDA when available, otherwise CPU
BINARY_DATA_ROOT = './data' # Local directory holding the full_binary_<category>.bin files

# Batch-size settings; get_adaptive_batch_size() chooses between the
# MAX/MIN bounds below based on total dataset size.
BATCH_SIZE = 128  # Base batch size
# Bounds returned by get_adaptive_batch_size() for the smallest / largest dataset tiers
MAX_BATCH_SIZE = 128  # Original note: sized for ml.g4dn.xlarge (16GB GPU)
MIN_BATCH_SIZE = 16   # Original note: kept small to prevent OOM errors with large datasets

# Add a function to calculate appropriate batch size
def get_adaptive_batch_size(num_samples_per_category, num_categories):
    """Choose a batch size from the overall dataset size.

    The total sample count (samples per category times category count) is
    matched against increasing size tiers; larger datasets get smaller
    batches to reduce the risk of out-of-memory errors.

    Args:
        num_samples_per_category: Number of samples drawn from each category.
        num_categories: Number of categories in the dataset.

    Returns:
        MAX_BATCH_SIZE below 10,000 samples, 64 below 50,000, 32 below
        200,000, and MIN_BATCH_SIZE otherwise.
    """
    dataset_size = num_samples_per_category * num_categories

    # (exclusive upper bound on dataset size, batch size), smallest tier first.
    size_tiers = (
        (10000, MAX_BATCH_SIZE),
        (50000, 64),
        (200000, 32),
    )
    for upper_bound, tier_batch_size in size_tiers:
        if dataset_size < upper_bound:
            return tier_batch_size

    # Anything at or above the last bound is treated as "very large".
    return MIN_BATCH_SIZE
    
# Fine-tuning hyperparameters
NUM_FINETUNE_EPOCHS = 5               # Number of fine-tuning epochs (NOTE(review): an older comment said 20 -- confirm the intended value)
FINETUNE_LEARNING_RATE = 5e-4          # Learning rate used during fine-tuning
FINETUNE_WEIGHT_DECAY = 1e-5           # Weight decay for regularization
MODEL_SAVE_PATH = './models'           # Directory to save fine-tuned models
VALIDATION_SPLIT = 0.1                 # Fraction of training data to use for validation
USE_GRADUAL_UNFREEZING = True          # Whether to use gradual unfreezing
USE_DATA_AUGMENTATION = True           # Whether to use data augmentation
GRADIENT_ACCUMULATION_STEPS = 1        # 1 = update weights after every batch
USE_GRADIENT_CHECKPOINTING = True      # Enable gradient checkpointing to save memory
CHECKPOINT_INTERVAL = 5                # Save checkpoints every N epochs
RESUME_FROM_CHECKPOINT = True          # Whether to resume from checkpoint if available


# --- Model Definitions ---
# Maps a display name to the torchvision constructor ("model_fn") and the
# pretrained ImageNet weights enum ("weights") for each backbone to evaluate.
MODELS_TO_TEST = {
    "MobileNetV3-Small": {
        "weights": models.MobileNet_V3_Small_Weights.IMAGENET1K_V1,
        "model_fn": models.mobilenet_v3_small,
    },
    "ShuffleNetV2_x0_5": {
        "weights": models.ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1,
        "model_fn": models.shufflenet_v2_x0_5,
    },
    "SqueezeNet1_1": {
        "weights": models.SqueezeNet1_1_Weights.IMAGENET1K_V1,
        "model_fn": models.squeezenet1_1,
    },
    # Disabled entry; uncomment to include EfficientNet-B0 in the benchmark.
    # "EfficientNet-B0": {
    #     "weights": models.EfficientNet_B0_Weights.IMAGENET1K_V1,
    #     "model_fn": models.efficientnet_b0,
    # }
}

# --- Logging Configuration ---
# Set to logging.DEBUG for verbose development output, logging.INFO for less
LOG_LEVEL = logging.INFO

# Create a logs directory if it doesn't exist
LOGS_DIR = './logs'
if not os.path.exists(LOGS_DIR):
    os.makedirs(LOGS_DIR, exist_ok=True)

# One log file per run: <notebook>_<timestamp>_<device>.log
notebook_name = 'quickdraw_benchmark'
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
log_file_name = f"{notebook_name}_{timestamp}_{device_type}.log"
log_file_path = os.path.join(LOGS_DIR, log_file_name)

# Use a dedicated named logger (not basicConfig) so handlers can be managed
# explicitly when this cell is executed more than once.
logger = logging.getLogger(__name__)
logger.setLevel(LOG_LEVEL) # Set the level for *this* logger

# Prevent duplicate handlers if the cell is run multiple times. Handlers are
# closed as well as detached so the previous run's log file handle is released.
for _stale_handler in list(logger.handlers):
    logger.removeHandler(_stale_handler)
    _stale_handler.close()

try:
    # File handler: mode='w' starts a fresh file (the name is unique per run anyway).
    file_handler = logging.FileHandler(log_file_path, mode='w')
    file_handler.setLevel(LOG_LEVEL)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(module)s - %(message)s')
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    # Stream handler for console output
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(LOG_LEVEL)
    stream_handler.setFormatter(formatter) # Use the same formatter
    logger.addHandler(stream_handler)

    logger.info(f"Logging initialized. Logs will be saved to {log_file_path}")

except Exception as e:
    # Logger methods do not accept a `flush` keyword (passing one raises
    # TypeError), and the logger may have no working handler at this point,
    # so report the setup failure with print() instead.
    print(f"Error setting up logging handlers: {e}", flush=True)
    # Fall back to console output through the root logger configured by basicConfig.
    logging.basicConfig(level=LOG_LEVEL, format='%(asctime)s - %(levelname)s - %(module)s - %(message)s')
    logger = logging.getLogger(__name__) # Same named logger; its records propagate to the root handlers
    logger.warning("File logging setup failed, falling back to console-only logging.")
    logger.info(f"Logging initialized (console only). Failed to save to {log_file_path}")




# Create the data directory if it does not exist. An empty directory is only
# a starting point: the per-category .bin files still have to be placed
# (or downloaded) into it.
if not os.path.exists(BINARY_DATA_ROOT):
    os.makedirs(BINARY_DATA_ROOT, exist_ok=True)
    logger.warning(f"Warning: Data directory '{BINARY_DATA_ROOT}' was not found and has been created.")
    logger.info(f"Please ensure QuickDraw .bin files (e.g., full_binary_apple.bin) for categories {QUICKDRAW_CATEGORIES} are placed there.")

# Create the directory for saving models if it doesn't exist
if not os.path.exists(MODEL_SAVE_PATH):
    os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
    logger.info(f"Created directory for saving models: {MODEL_SAVE_PATH}")

def clear_gpu_memory():
    """Empty PyTorch's CUDA cache; does nothing when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()


def get_git_info():
    """Describe the current git HEAD commit for use in version/file names.

    Returns:
        dict with keys:
            "available": True when commit information could be read.
            "commit": first 8 characters of the HEAD commit hash, or "no_git".
            "message": first line of the commit message made filename-safe
                and truncated to 30 characters, or "no_git".

        The "no_git" form is returned when gitpython is not installed, when
        no repository is found, or when the repository has no commits yet.
    """
    no_git_info = {"available": False, "commit": "no_git", "message": "no_git"}
    if not GIT_AVAILABLE:
        return no_git_info

    try:
        repo = git.Repo(search_parent_directories=True)
        head_commit = repo.head.commit
        commit_hash = head_commit.hexsha[:8]  # Short hash
        commit_message = head_commit.message.strip().split('\n')[0]  # First line only
    except (git.InvalidGitRepositoryError, git.NoSuchPathError, ValueError):
        # ValueError: HEAD points at a branch with no commits (fresh repository).
        return no_git_info

    # Replace spaces and path/drive separators as before, then neutralise the
    # remaining characters that are invalid in file names on common platforms.
    safe_message = commit_message.replace(' ', '_').replace('/', '-').replace(':', '-')
    forbidden_chars = set('\\*?"<>|')
    safe_message = ''.join(
        '_' if (ch in forbidden_chars or ord(ch) < 32) else ch
        for ch in safe_message
    )[:30]

    return {
        "available": True,
        "commit": commit_hash,
        "message": safe_message
    }

In [None]:

# --- Part 1: Data Download and Preparation ---
# %%
def download_quickdraw_binary(category_name, download_dir, timeout=(10, 60)):
    """
    Downloads the .bin file for a given QuickDraw category.
    Files are named 'full_binary_{category_name_underscored}.bin'.

    The data is streamed into a temporary '<name>.bin.part' file and only
    renamed to the final name once the transfer has finished, so an
    interrupted download is never mistaken for a complete file on a later run.

    Args:
        category_name: QuickDraw category, e.g. "ice cream".
        download_dir: Existing directory to store the file in.
        timeout: Passed to requests as the (connect, read) timeout in seconds,
            so a stalled connection cannot block the notebook indefinitely.

    Returns:
        None. Failures are logged (best effort) rather than raised; the final
        file simply does not exist afterwards.
    """
    # Sanitize category name for filename (replace spaces with underscores)
    filename_category_part = category_name.replace(' ', '_')
    local_filename = f"full_binary_{filename_category_part}.bin"
    local_filepath = os.path.join(download_dir, local_filename)
    partial_filepath = local_filepath + ".part"

    if os.path.exists(local_filepath):
        logger.info(f"File for '{category_name}' already exists: {local_filepath}")
        return

    # URL encode category name for the download URL (e.g., "ice cream" -> "ice%20cream")
    url_category_part = urllib.parse.quote(category_name)
    url = f"https://storage.googleapis.com/quickdraw_dataset/full/binary/{url_category_part}.bin"

    logger.info(f"Downloading '{category_name}' from {url} to {local_filepath}...")
    completed = False
    try:
        # The context manager releases the connection even when streaming fails.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()  # Raise an exception for HTTP errors

            # A missing content-length header means the total size is unknown.
            total_size = int(response.headers.get('content-length', 0)) or None

            with open(partial_filepath, 'wb') as f, tqdm(
                desc=category_name,
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for chunk in response.iter_content(chunk_size=8192):
                    bar.update(f.write(chunk))

        # Publish the file under its final name only after a full transfer.
        os.replace(partial_filepath, local_filepath)
        completed = True
        logger.info(f"Successfully downloaded '{category_name}'.")
    except requests.exceptions.RequestException as e:
        logger.error(f"Error downloading '{category_name}': {e}")
    except Exception as e:
        logger.error(f"An unexpected error occurred while downloading '{category_name}': {e}")
    finally:
        # Clean up a partial download on any failure, including interruption.
        if not completed and os.path.exists(partial_filepath):
            os.remove(partial_filepath)


# Fetch every configured category into the data directory; categories whose
# file is already present are skipped inside download_quickdraw_binary.
logger.info(f"Starting download process for {len(QUICKDRAW_CATEGORIES)} categories into '{BINARY_DATA_ROOT}'...")
for category in QUICKDRAW_CATEGORIES:
    download_quickdraw_binary(category, BINARY_DATA_ROOT)
logger.info("Download process finished.")

# Explicitly flush file handlers after a significant phase so the log file
# on disk is up to date.
for handler in logger.handlers:
    if isinstance(handler, logging.FileHandler):
        handler.flush()



# QuickDrawBinaryDataset Loader and Indexer

In [None]:
# --- QuickDraw Binary Data Reading Functions (from user - Unchanged) ---
def unpack_drawing(file_handle):
    """Read one drawing record from a binary stream at its current position.

    Record layout as parsed here (explicit little-endian, standard sizes, so
    the result does not depend on the host's byte order or native type sizes):
        uint64 key_id, 2-byte country code, int8 recognized,
        uint32 timestamp, uint16 stroke count,
        then per stroke: uint16 point count, N x-bytes, N y-bytes.

    Args:
        file_handle: Binary file-like object supporting read().

    Returns:
        dict with 'key_id', 'country_code' (bytes), 'recognized', 'timestamp'
        and 'image' (list of (x_tuple, y_tuple) pairs, one per stroke).

    Raises:
        struct.error: When the stream ends before a full record was read
            (including a clean end of file). Callers use this as the
            end-of-data signal.
    """
    header_format = '<Q2sbIH'
    try:
        header_bytes = file_handle.read(struct.calcsize(header_format))
        key_id, country_code, recognized, timestamp, n_strokes = struct.unpack(header_format, header_bytes)

        image_strokes = []
        for _ in range(n_strokes):
            n_points, = struct.unpack('<H', file_handle.read(2))
            if n_points == 0:
                image_strokes.append((tuple(), tuple()))
                continue

            x_bytes = file_handle.read(n_points)
            y_bytes = file_handle.read(n_points)

            if len(x_bytes) < n_points or len(y_bytes) < n_points:
                logger.error(f"Insufficient data for stroke points. Expected {n_points}, got {len(x_bytes)} for x, {len(y_bytes)} for y. Skipping drawing.")
                raise struct.error("Insufficient data for stroke points, likely corrupted drawing record.")

            # Each coordinate is a single unsigned byte, so iterating the
            # bytes objects yields the coordinate values directly.
            image_strokes.append((tuple(x_bytes), tuple(y_bytes)))

        return {
            'key_id': key_id,
            'country_code': country_code,
            'recognized': recognized,
            'timestamp': timestamp,
            'image': image_strokes
        }
    except struct.error as e:
        logger.debug(f"Struct error during unpack_drawing: {e}. File pointer at {file_handle.tell() if hasattr(file_handle, 'tell') else 'N/A'}")
        raise
    except Exception as e:
        logger.error(f"Unexpected error in unpack_drawing: {e}")
        raise


def unpack_drawings(filename):
    """Yield every drawing stored in a QuickDraw binary file, in file order.

    A byte-based progress bar tracks how far into the file the reader is.
    Iteration ends at end of file, or as soon as a record cannot be parsed
    (struct.error / EOFError from unpack_drawing).

    Args:
        filename: Path to a .bin file.

    Yields:
        The dict produced by unpack_drawing for each record.
    """
    total_bytes = os.path.getsize(filename)
    progress_label = f"Unpacking {os.path.basename(filename)}"

    with open(filename, 'rb') as f:
        with tqdm(total=total_bytes, unit='B', unit_scale=True, desc=progress_label, leave=False) as pbar:
            while True:
                try:
                    record_start = f.tell()
                    if record_start >= total_bytes:
                        break
                    yield unpack_drawing(f)
                    # Advance the bar by the size of the record just consumed.
                    pbar.update(f.tell() - record_start)
                except struct.error:
                    logger.debug(f"Struct.error in unpack_drawings, likely end of file or data. File: {filename}")
                    break
                except EOFError:
                    logger.debug(f"EOFError in unpack_drawings. File: {filename}")
                    break

def precompute_all_indices(root_dir, categories):
    """
    Precompute indices for all categories in the binary dataset.

    For each category, the binary file 'full_binary_<category>.bin' is scanned
    record by record and the byte offset of every drawing that parses
    successfully is pickled to 'full_binary_<category>.idx'. Categories whose
    index already exists, or whose binary file is missing, are skipped.

    Args:
        root_dir: Directory holding the .bin files (created if missing).
        categories: Iterable of category names; spaces map to underscores in
            file names.

    Returns:
        True once all categories have been processed.
    """
    os.makedirs(root_dir, exist_ok=True)

    for category in tqdm(categories, desc="Precomputing indices"):
        # Convert category name for filename
        category_file = category.replace(' ', '_')
        filepath = os.path.join(root_dir, f"full_binary_{category_file}.bin")
        index_path = os.path.join(root_dir, f"full_binary_{category_file}.idx")

        if os.path.exists(index_path):
            logger.info(f"Index for {category} already exists at {index_path}")
            continue

        if not os.path.exists(filepath):
            logger.warning(f"Binary file for {category} not found at {filepath}, skipping index creation")
            continue

        logger.info(f"Creating index for {category}...")
        drawing_offsets = []

        file_size = os.path.getsize(filepath)
        with open(filepath, 'rb') as f:
            with tqdm(total=file_size, unit='B', unit_scale=True, desc=f"Indexing {os.path.basename(filepath)}") as pbar:
                while True:
                    current_pos = f.tell()
                    if current_pos >= file_size:
                        break
                    try:
                        # Read one drawing to validate it and advance the pointer
                        unpack_drawing(f)
                    except struct.error:
                        logger.debug(f"Finished indexing or encountered struct.error at offset {current_pos}")
                        break
                    except EOFError:
                        logger.debug(f"EOFError encountered while indexing at offset {current_pos}")
                        break
                    except Exception as e:
                        logger.error(f"Unexpected error during indexing at offset {current_pos}: {e}")
                        break

                    # Record the offset only after the record parsed cleanly, so a
                    # truncated or corrupt trailing record never enters the index.
                    drawing_offsets.append(current_pos)
                    pbar.update(f.tell() - current_pos)

        # Save to disk
        with open(index_path, 'wb') as f:
            pickle.dump(drawing_offsets, f)

        logger.info(f"Saved index with {len(drawing_offsets)} entries to {index_path}")

    logger.info(f"Precomputed indices for {len(categories)} categories")
    return True

# --- Custom QuickDraw Dataset from Local Binary Files (REFACTORED) ---
class QuickDrawBinaryDataset(Dataset):
    """Dataset over one QuickDraw category stored as a local binary file.

    On construction the file 'full_binary_<category>.bin' under `root` is
    indexed (byte offset of every drawing); the index is cached in memory per
    file path and pickled next to the binary as '.idx'. Items are read lazily
    by seeking to the stored offset, rendered to a grayscale PIL image and
    optionally transformed.

    Each item is `(image, category)` where `category` is the sanitized
    category name (spaces replaced by underscores).
    """
    IMAGE_SIZE = (256, 256)
    LINE_WIDTH = 2
    # Class-level dictionary of drawing offsets, keyed by binary file path and
    # shared by all instances.
    _cached_drawing_offsets = {}

    def __init__(self, root, category, transform=None, cache_size=QUICKDRAW_CACHE_SIZE):
        """
        Args:
            root: Directory containing the binary file.
            category: Category name; spaces are replaced by underscores.
            transform: Optional callable applied to the rendered PIL image.
            cache_size: Maximum number of decoded drawings kept per worker
                cache; 0 (or less) disables caching.

        Raises:
            FileNotFoundError: If the category's binary file does not exist.
        """
        self.root = root
        # Sanitize category name for filename (replace spaces with underscores)
        self.category = category.replace(' ', '_')
        self.transform = transform
        self.filepath = os.path.join(self.root, f"full_binary_{self.category}.bin")
        self.index_path = os.path.join(self.root, f"full_binary_{self.category}.idx")

        self.cache_size = cache_size
        self.worker_caches = {}  # Dictionary to store worker-specific caches

        if not os.path.exists(self.filepath):
            raise FileNotFoundError(
                f"Dataset binary file not found: {self.filepath}. Please ensure it exists."
            )

        # Load or create the index for this file path
        self.drawing_offsets = self._get_or_create_index()

        if not self.drawing_offsets:
            logger.warning(f"No drawings were indexed for category {self.category} from {self.filepath}.")
        else:
            logger.info(f"Successfully loaded or indexed {len(self.drawing_offsets)} drawings for {self.category}. Cache capacity: {self.cache_size} items.")

    def _get_or_create_index(self):
        """Return the list of drawing byte offsets for this file.

        Lookup order: in-memory cache, pickled '.idx' file, then a full scan
        of the binary file (whose result is cached and written to disk).
        """
        # First check in-memory cache
        if self.filepath in self._cached_drawing_offsets:
            logger.debug(f"Using in-memory cached index for {self.filepath}")
            return self._cached_drawing_offsets[self.filepath]

        # Then check for pre-computed index file
        if os.path.exists(self.index_path):
            try:
                logger.info(f"Loading pre-computed index from {self.index_path}")
                # NOTE(review): pickle.load executes arbitrary code from the
                # file; only load index files produced by this notebook.
                with open(self.index_path, 'rb') as f:
                    drawing_offsets = pickle.load(f)
                # Store in memory cache
                self._cached_drawing_offsets[self.filepath] = drawing_offsets
                return drawing_offsets
            except Exception as e:
                logger.error(f"Error loading pre-computed index from {self.index_path}: {e}")
                # Fall back to creating index

        # If we get here, we need to create the index
        logger.info(f"Indexing drawings from {self.filepath} for category {self.category}...")
        drawing_offsets = []

        file_size = os.path.getsize(self.filepath)
        with open(self.filepath, 'rb') as idx_file_handle:
            with tqdm(total=file_size, unit='B', unit_scale=True, desc=f"Indexing {os.path.basename(self.filepath)}", leave=False) as pbar:
                while True:
                    current_pos = idx_file_handle.tell()
                    if current_pos >= file_size:
                        break
                    try:
                        # Read one drawing just to advance the pointer and validate structure
                        unpack_drawing(idx_file_handle)
                    except struct.error:
                        logger.debug(f"Finished indexing or encountered struct.error at offset {current_pos} in {self.filepath}. Total indexed: {len(drawing_offsets)}")
                        break
                    except EOFError:
                        logger.debug(f"EOFError encountered while indexing {self.filepath} at offset {current_pos}. Total indexed: {len(drawing_offsets)}")
                        break
                    except Exception as e:
                        logger.error(f"Unexpected error during indexing of {self.filepath} at offset {current_pos}: {e}. Stopping indexing for this file.")
                        break

                    # Only offsets of records that parsed successfully are
                    # indexed; otherwise a truncated trailing record would
                    # become an item that always fails to load.
                    drawing_offsets.append(current_pos)
                    pbar.update(idx_file_handle.tell() - current_pos)

        # Store the index in the memory cache
        self._cached_drawing_offsets[self.filepath] = drawing_offsets

        # Also save to disk for future use
        try:
            with open(self.index_path, 'wb') as f:
                pickle.dump(drawing_offsets, f)
                logger.info(f"Saved index with {len(drawing_offsets)} entries to {self.index_path}")
        except Exception as e:
            logger.error(f"Failed to save index to {self.index_path}: {e}")

        return drawing_offsets

    def _render_drawing_to_image(self, drawing_strokes):
        """Rasterize strokes (black on white) into an 'L' mode PIL image."""
        image = Image.new("L", self.IMAGE_SIZE, "white")
        draw = ImageDraw.Draw(image)
        for stroke_x, stroke_y in drawing_strokes:
            if not stroke_x or not stroke_y:
                continue
            if len(stroke_x) == 1:
                # A single-point stroke cannot be drawn as a line.
                draw.point((int(stroke_x[0]), int(stroke_y[0])), fill="black")
            else:
                points = list(zip(stroke_x, stroke_y))
                draw.line(points, fill="black", width=self.LINE_WIDTH)
        return image

    def __len__(self):
        return len(self.drawing_offsets)

    def __getitem__(self, idx):
        """Load, render and transform the drawing at position `idx`.

        Raises:
            IndexError: If `idx` is negative or past the last drawing.
            IOError: If the drawing cannot be read from the binary file.
        """
        if idx < 0 or idx >= len(self.drawing_offsets):
            raise IndexError(f"Index {idx} out of bounds for {len(self.drawing_offsets)} drawings.")

        # Get worker info for process-specific caching
        worker_info = torch.utils.data.get_worker_info()
        worker_id = worker_info.id if worker_info else 0

        # Create worker-specific cache if it doesn't exist
        if worker_id not in self.worker_caches:
            self.worker_caches[worker_id] = collections.OrderedDict()

        # Use worker-specific cache
        worker_cache = self.worker_caches[worker_id]

        drawing_data = None
        if idx in worker_cache:
            drawing_data = worker_cache[idx]
            worker_cache.move_to_end(idx)  # Mark as recently used
        else:
            offset = self.drawing_offsets[idx]
            try:
                # Open file, seek, read one drawing, then close.
                # This is safer for multiprocessing in DataLoader.
                with open(self.filepath, 'rb') as f:
                    f.seek(offset)
                    drawing_data = unpack_drawing(f)
            except Exception as e:
                logger.error(f"Error reading drawing at index {idx}, offset {offset} from {self.filepath}: {e}")
                raise IOError(f"Failed to load drawing {idx} for {self.category}") from e

            if self.cache_size > 0:
                worker_cache[idx] = drawing_data
                if len(worker_cache) > self.cache_size:
                    worker_cache.popitem(last=False)  # Remove oldest item (LRU)

        if drawing_data is None: # Should not happen if logic is correct
             raise RuntimeError(f"Drawing data for index {idx} could not be retrieved.")

        pil_image = self._render_drawing_to_image(drawing_data['image'])

        if self.transform:
            pil_image = self.transform(pil_image)

        return pil_image, self.category

### Generate CSV of the 51 QuickDraw Categories

In [None]:
def generate_category_counts_report():
    """Count the drawings available per category and write a CSV summary.

    For every entry in QUICKDRAW_CATEGORIES the binary file is opened through
    QuickDrawBinaryDataset to get the total number of drawings, and the
    train/validation/test sizes that the configured per-category limits would
    yield are derived from it. Categories whose file is missing, or that fail
    to load, are reported with a total of 0 and an 'Error' entry.

    The rows are written to 'quickdraw_category_counts.csv' in the current
    working directory; a failure to write is logged, not raised.

    Returns:
        list[dict]: One row per category, with the same keys as the CSV.
    """
    report_rows = []

    for category_name in tqdm(QUICKDRAW_CATEGORIES, desc="Counting category samples"):
        try:
            bin_filename = f"full_binary_{category_name.replace(' ', '_')}.bin"
            if not os.path.exists(os.path.join(BINARY_DATA_ROOT, bin_filename)):
                report_rows.append({
                    'Category': category_name,
                    'Total Samples': 0,
                    'Error': 'File not found'
                })
                continue

            category_dataset = QuickDrawBinaryDataset(
                root=BINARY_DATA_ROOT,
                category=category_name,
                transform=None,
                cache_size=QUICKDRAW_CACHE_SIZE
            )
            available = len(category_dataset)

            # Split sizes: at most 70% of the data for training, then up to
            # half of what is left for validation, the rest capped for test.
            train_count = min(NUM_TRAIN_SAMPLES_PER_CATEGORY, int(available * 0.7))
            leftover = available - train_count
            val_count = min(int(NUM_TRAIN_SAMPLES_PER_CATEGORY * 0.2), int(leftover * 0.5))
            test_count = min(NUM_TEST_SAMPLES_PER_CATEGORY, leftover - val_count)

            report_rows.append({
                'Category': category_name,
                'Total Samples': available,
                'Training Samples': train_count,
                'Validation Samples': val_count,
                'Test Samples': test_count
            })

            logger.info(f"Category: {category_name}, Total: {available}, Train: {train_count}, Val: {val_count}, Test: {test_count}")

        except Exception as e:
            logger.error(f"Error counting samples for category {category_name}: {e}")
            report_rows.append({
                'Category': category_name,
                'Total Samples': 0,
                'Error': str(e)
            })

    # Persist the report; the column list covers both success and error rows.
    csv_filename = "quickdraw_category_counts.csv"
    csv_columns = ['Category', 'Total Samples', 'Training Samples', 'Validation Samples', 'Test Samples', 'Error']

    try:
        with open(csv_filename, 'w', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=csv_columns)
            writer.writeheader()
            writer.writerows(report_rows)
        logger.info(f"Category counts saved to {csv_filename}")
    except IOError as e:
        logger.error(f"Could not save category counts to CSV: {e}")

    return report_rows

# Build the per-category count report (also written to quickdraw_category_counts.csv)
category_stats = generate_category_counts_report()


##  Data Preparation with Augmentation 

In [None]:
# --- Data Preparation with Augmentation ---
def get_augmented_quickdraw_data(categories, num_train_per_cat, num_test_per_cat, data_root, seed=None):
    """
    Build train/validation/test datasets from local QuickDraw binary files,
    applying data augmentation to the training split only.

    Args:
        categories: Category names. The position of a category in this
            sequence is its integer class label, even if other categories
            are skipped.
        num_train_per_cat: Maximum training samples per category (capped at
            70% of the drawings available). Validation uses up to 20% of this
            number, capped at half of the remaining drawings.
        num_test_per_cat: Maximum test samples per category.
        data_root: Directory containing the 'full_binary_<category>.bin' files.
        seed: Optional seed for the per-category shuffle. When None, the
            global `random` state is used.

    Returns:
        (train_dataset, val_dataset, test_dataset) as ConcatDataset objects
        whose items are (image_tensor, label_index).

    Raises:
        RuntimeError: If no category yielded usable data.
    """
    logger.info(f"Loading augmented QuickDraw data for {len(categories)} categories...")

    # A private generator keeps a seeded split independent of other users of
    # the global random state.
    shuffler = random.Random(seed) if seed is not None else random

    # Create standard transforms
    base_transform = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Augmented transform for training
    # NOTE(review): rotation/affine use the default fill, while drawings are
    # rendered on a white background -- confirm that is intended.
    train_transform = T.Compose([
        T.Grayscale(num_output_channels=3),
        T.RandomHorizontalFlip(p=0.5),
        T.RandomRotation(15),
        T.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        T.ColorJitter(brightness=0.2, contrast=0.2),
        base_transform
    ])

    # Simple transform for validation/testing
    test_transform = T.Compose([
        T.Grayscale(num_output_channels=3),
        base_transform
    ])

    all_train_datasets = []
    all_val_datasets = []
    all_test_datasets = []
    skipped_categories = []

    for category_idx, category_name in enumerate(categories):
        try:
            logger.info(f"Loading QuickDraw category: {category_name} from local binary files...")
            full_category_dataset = QuickDrawBinaryDataset(
                root=data_root,
                category=category_name,
                transform=None,  # No transform yet, will be applied per-split
                cache_size=QUICKDRAW_CACHE_SIZE
            )

            total_available = len(full_category_dataset)
            if total_available == 0:
                logger.warning(f"No data loaded for category {category_name}. Skipping.")
                skipped_categories.append(category_name)
                continue

            # Calculate actual samples to use (limited by available data).
            # Counts are truncated to integers before the remainder is taken,
            # matching the arithmetic in generate_category_counts_report.
            actual_num_train = int(min(num_train_per_cat, total_available * 0.7))
            remaining_samples = total_available - actual_num_train

            # Split remaining samples between validation and test
            actual_num_val = int(min(num_train_per_cat * 0.2, remaining_samples * 0.5))
            actual_num_test = int(min(num_test_per_cat, remaining_samples - actual_num_val))

            logger.info(f"Category {category_name}: {actual_num_train} train, {actual_num_val} val, {actual_num_test} test")

            # Ensure we have enough samples
            if actual_num_train == 0 or actual_num_val == 0 or actual_num_test == 0:
                logger.warning(f"Not enough samples in {category_name} for desired splits. Skipping.")
                skipped_categories.append(category_name)
                continue

            # Create indices for random, non-overlapping splits
            indices = list(range(total_available))
            shuffler.shuffle(indices)

            val_start = actual_num_train
            test_start = val_start + actual_num_val
            train_indices = indices[:val_start]
            val_indices = indices[val_start:test_start]
            test_indices = indices[test_start:test_start + actual_num_test]

            # Create labeled datasets with appropriate transforms
            train_subset = TransformedSubset(full_category_dataset, train_indices, train_transform, category_idx)
            val_subset = TransformedSubset(full_category_dataset, val_indices, test_transform, category_idx)
            test_subset = TransformedSubset(full_category_dataset, test_indices, test_transform, category_idx)

            all_train_datasets.append(train_subset)
            all_val_datasets.append(val_subset)
            all_test_datasets.append(test_subset)

        except Exception as e:
            logger.error(f"Error loading category {category_name}: {e}")
            skipped_categories.append(category_name)
            continue

    # The three lists always grow together, so checking one is sufficient.
    if not all_train_datasets:
        raise RuntimeError("No QuickDraw data could be loaded for any category. Aborting.")
    if skipped_categories:
        logger.warning("Some categories failed to load, proceeding with available data.")

    # Combine datasets across categories
    train_dataset = ConcatDataset(all_train_datasets)
    val_dataset = ConcatDataset(all_val_datasets)
    test_dataset = ConcatDataset(all_test_datasets)

    logger.info(f"Created datasets with {len(train_dataset)} training samples, "
               f"{len(val_dataset)} validation samples, and {len(test_dataset)} test samples.")

    return train_dataset, val_dataset, test_dataset

# Helper class for applying transforms during subset creation
class TransformedSubset(Dataset):
    """
    A subset of a dataset with a transform that's applied on-the-fly.
    Also assigns a fixed class label for classification tasks.

    Args:
        dataset: Indexable source whose items unpack as ``(image, _)``.
        indices: Positions in ``dataset`` that make up this subset.
        transform: Callable applied to each image; skipped when falsy.
        label: Class label returned with every item of this subset.
        placeholder_shape: Shape of the all-zero tensor returned when an item
            cannot be loaded. It has to match what ``transform`` produces or
            batching will fail; the default keeps the previous hard-coded
            3x224x224 (assumed to be the transform output size - TODO confirm).
    """
    def __init__(self, dataset, indices, transform, label, placeholder_shape=(3, 224, 224)):
        self.dataset = dataset
        self.indices = indices
        self.transform = transform
        self.label = label
        self.placeholder_shape = tuple(placeholder_shape)

    def __getitem__(self, idx):
        """
        Return ``(image, label)`` for position ``idx`` of the subset.

        Raises:
            IndexError: if ``idx`` is outside ``indices``.
        """
        # Resolve the position outside the try block: an out-of-range idx must
        # raise IndexError instead of being turned into a placeholder, otherwise
        # plain iteration over the subset (which stops on IndexError) never ends.
        source_idx = self.indices[idx]
        try:
            image, _ = self.dataset[source_idx]
            if self.transform:
                image = self.transform(image)
            return image, self.label
        except Exception as e:
            # Best effort: one unreadable sample should not abort a whole epoch.
            logger.warning(f"Error loading image at index {idx}: {e}")
            # Create a placeholder image (all zeros, i.e. a black square)
            placeholder = torch.zeros(*self.placeholder_shape)
            return placeholder, self.label

    def __len__(self):
        """Number of items in the subset."""
        return len(self.indices)

# Model Classes

In [None]:
# --- Model Wrapper Architecture for Ensemble ---
class ModelWrapper(nn.Module):
    """
    Base wrapper for models to ensure consistent interface in ensemble.

    Args:
        model: The wrapped network; ``forward`` delegates to it.
        num_classes: Number of output classes, stored for reference.
    """
    def __init__(self, model=None, num_classes=len(QUICKDRAW_CATEGORIES)):
        super().__init__()
        self.model = model
        self.num_classes = num_classes

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its output."""
        return self.model(x)

    @classmethod
    def load_from_checkpoint(cls, model_path, device=DEVICE):
        """
        Build ``cls()`` and load a state dict from ``model_path`` into its model.

        The wrapped model is moved to ``device`` and put in eval mode.

        Returns:
            The loaded wrapper, or ``None`` if anything fails (the error is logged).
        """
        try:
            instance = cls()
            if instance.model is None:
                # cls() on the bare base class has no network to load into; say
                # so explicitly instead of failing later with an AttributeError
                # on None, and before reading the checkpoint file.
                raise ValueError(
                    f"{cls.__name__}() has no underlying model; "
                    "load through a concrete wrapper subclass")
            state_dict = torch.load(model_path, map_location=device)
            instance.model.load_state_dict(state_dict)
            instance.model.to(device)
            instance.model.eval()
            return instance
        except Exception as e:
            logger.error(f"Failed to load model from {model_path}: {e}")
            return None

class ShuffleNetV2Wrapper(ModelWrapper):
    """ShuffleNetV2 x0.5 whose ``fc`` layer is replaced by a three-stage MLP head."""
    def __init__(self, weights=None, num_classes=len(QUICKDRAW_CATEGORIES)):
        super().__init__(None, num_classes)
        backbone = models.shufflenet_v2_x0_5(weights=weights)
        feature_dim = backbone.fc.in_features

        # Head layers, listed in execution order; their positions in the
        # Sequential determine the state-dict keys, so the order must not change.
        head_layers = [
            nn.Linear(feature_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        ]
        backbone.fc = nn.Sequential(*head_layers)
        self.model = backbone

    @classmethod
    def load_from_checkpoint(cls, model_path, device=DEVICE):
        """Load a state dict into a fresh wrapper; return it in eval mode, or None on failure."""
        wrapper = cls(weights=None)
        try:
            weights_blob = torch.load(model_path, map_location=device)
            network = wrapper.model
            network.load_state_dict(weights_blob)
            network.to(device)
            network.eval()
        except Exception as e:
            logger.error(f"Failed to load ShuffleNetV2 from {model_path}: {e}")
            return None
        else:
            return wrapper

class MobileNetV3Wrapper(ModelWrapper):
    """MobileNetV3-Small whose ``classifier`` is replaced by a three-stage MLP head."""
    def __init__(self, weights=None, num_classes=len(QUICKDRAW_CATEGORIES)):
        super().__init__(None, num_classes)
        backbone = models.mobilenet_v3_small(weights=weights)
        # Input width of the stock classifier's first layer feeds the new head.
        feature_dim = backbone.classifier[0].in_features

        # Positions in the Sequential determine the state-dict keys; keep the order.
        head_layers = [
            nn.Linear(feature_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        ]
        backbone.classifier = nn.Sequential(*head_layers)
        self.model = backbone

    @classmethod
    def load_from_checkpoint(cls, model_path, device=DEVICE):
        """Load a state dict into a fresh wrapper; return it in eval mode, or None on failure."""
        wrapper = cls(weights=None)
        try:
            weights_blob = torch.load(model_path, map_location=device)
            network = wrapper.model
            network.load_state_dict(weights_blob)
            network.to(device)
            network.eval()
        except Exception as e:
            logger.error(f"Failed to load MobileNetV3 from {model_path}: {e}")
            return None
        else:
            return wrapper

class SqueezeNetWrapper(ModelWrapper):
    """SqueezeNet 1.1 whose ``classifier`` is replaced by a two-convolution head."""
    def __init__(self, weights=None, num_classes=len(QUICKDRAW_CATEGORIES)):
        super().__init__(None, num_classes)
        backbone = models.squeezenet1_1(weights=weights)
        # Channel count entering the stock classifier's convolution.
        head_in_channels = backbone.classifier[1].in_channels

        # Convolutional head: 1x1 conv -> ReLU -> global average pool -> 1x1 conv.
        # Positions in the Sequential determine the state-dict keys; keep the order.
        head_layers = [
            nn.Dropout(p=0.5),
            nn.Conv2d(head_in_channels, 512, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(512, num_classes, kernel_size=1),
        ]
        backbone.classifier = nn.Sequential(*head_layers)
        self.model = backbone

    @classmethod
    def load_from_checkpoint(cls, model_path, device=DEVICE):
        """Load a state dict into a fresh wrapper; return it in eval mode, or None on failure."""
        wrapper = cls(weights=None)
        try:
            weights_blob = torch.load(model_path, map_location=device)
            network = wrapper.model
            network.load_state_dict(weights_blob)
            network.to(device)
            network.eval()
        except Exception as e:
            logger.error(f"Failed to load SqueezeNet from {model_path}: {e}")
            return None
        else:
            return wrapper

### Utility: Checkpointing during Training

In [None]:
# --- Checkpoint Utility Functions ---
def _checkpoint_naming():
    """
    Return ``(git_info, name_suffix)`` shared by every checkpoint filename.

    ``name_suffix`` is ``"[_<first 10 chars of commit message>...]_samples<N>_epochs<M>"``.
    Path separators in the commit-message fragment are replaced with ``_``:
    left as-is they would turn the filename into a path inside a directory
    that does not exist. Fragments without separators are unchanged, so
    checkpoints written by earlier runs are still found.
    """
    git_info = get_git_info()
    params_str = f"samples{NUM_TRAIN_SAMPLES_PER_CATEGORY}_epochs{NUM_FINETUNE_EPOCHS}"
    if git_info["available"]:
        message_part = f"{git_info['message'][:10]}"
        for separator in (os.sep, os.altsep):
            if separator:  # os.altsep is None on POSIX
                message_part = message_part.replace(separator, "_")
        git_str = f"_{message_part}..."
    else:
        git_str = ""
    return git_info, f"{git_str}_{params_str}"


def save_checkpoint(model, optimizer, scheduler, epoch, best_val_accuracy, early_stopping_counter, model_name, is_best=False):
    """
    Save a checkpoint of the model, optimizer, scheduler and training state with git info.

    Always (over)writes the "latest" checkpoint for ``model_name``. When
    ``is_best`` is true it additionally writes the bare model state dict to
    ``MODEL_SAVE_PATH`` and a full "best" checkpoint next to the latest one.

    Args:
        model, optimizer: Objects whose ``state_dict()`` is stored.
        scheduler: LR scheduler, or ``None`` (stored as ``None``).
        epoch: Index of the epoch that just finished.
        best_val_accuracy: Best validation accuracy seen so far.
        early_stopping_counter: Current early-stopping counter.
        model_name: Used for the checkpoint sub-directory and best-model filename.
        is_best: Whether this epoch produced the best model so far.
    """
    # Create checkpoint directory if it doesn't exist
    checkpoint_dir = os.path.join(MODEL_SAVE_PATH, "checkpoints", model_name)
    os.makedirs(checkpoint_dir, exist_ok=True)

    git_info, name_suffix = _checkpoint_naming()

    # Prepare checkpoint data
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
        'best_val_accuracy': best_val_accuracy,
        'early_stopping_counter': early_stopping_counter,
        'git_info': git_info
    }

    # Save latest epoch checkpoint (always overwriting previous)
    latest_path = os.path.join(checkpoint_dir, f'checkpoint_latest{name_suffix}.pth')
    torch.save(checkpoint, latest_path)
    logger.info(f"Latest checkpoint saved at {latest_path}")

    # If this is the best model so far, save it separately
    if is_best:
        best_path = os.path.join(MODEL_SAVE_PATH, f"{model_name}_best{name_suffix}.pth")
        torch.save(model.state_dict(), best_path)
        logger.info(f"Best model saved at {best_path}")

        # Also save a complete checkpoint for the best model
        best_checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_best{name_suffix}.pth')
        torch.save(checkpoint, best_checkpoint_path)

def load_checkpoint(model, optimizer, scheduler, model_name):
    """
    Load the latest checkpoint if it exists.

    Restores ``model``, ``optimizer`` and (when both the scheduler and its
    saved state are present) ``scheduler`` in place.

    Returns:
        ``(start_epoch, best_val_accuracy, early_stopping_counter)`` where
        ``start_epoch`` is the saved epoch + 1. ``(0, 0.0, 0)`` is returned
        when no checkpoint file exists or loading fails (the error is logged).
    """
    checkpoint_dir = os.path.join(MODEL_SAVE_PATH, "checkpoints", model_name)

    # Must build the name exactly as save_checkpoint does.
    _, name_suffix = _checkpoint_naming()
    latest_path = os.path.join(checkpoint_dir, f'checkpoint_latest{name_suffix}.pth')

    # If the checkpoint file doesn't exist, return initial values
    if not os.path.exists(latest_path):
        logger.info(f"No checkpoint found at {latest_path}, starting from scratch.")
        return 0, 0.0, 0

    # Load the checkpoint
    logger.info(f"Loading checkpoint from {latest_path}")
    try:
        checkpoint = torch.load(latest_path, map_location=DEVICE)

        # Pull out every required entry first so a checkpoint with missing keys
        # is rejected before the model or optimizer has been modified.
        epoch = checkpoint['epoch']
        best_val_accuracy = checkpoint['best_val_accuracy']
        early_stopping_counter = checkpoint['early_stopping_counter']
        model_state = checkpoint['model_state_dict']
        optimizer_state = checkpoint['optimizer_state_dict']
        scheduler_state = checkpoint.get('scheduler_state_dict')

        # Load model weights and optimizer state
        model.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

        # Load scheduler state if it exists
        if scheduler is not None and scheduler_state is not None:
            scheduler.load_state_dict(scheduler_state)

        logger.info(f"Resuming from epoch {epoch+1} with best validation accuracy: {best_val_accuracy:.2f}%")
        return epoch + 1, best_val_accuracy, early_stopping_counter

    except Exception as e:
        logger.error(f"Error loading checkpoint: {e}")
        return 0, 0.0, 0

# Template code for resuming specific model training
def resume_specific_model_training(model_name):
    """
    Template: rebuild a model with its optimizer/scheduler and restore the latest checkpoint.

    Looks up ``model_name`` in ``MODELS_TO_TEST``, creates the model, an Adam
    optimizer and a ReduceLROnPlateau scheduler, then calls ``load_checkpoint``.
    The classifier replacement and the training loop are intentionally left as
    placeholders, so this function currently only restores and logs state.
    """
    # Set up model and optimizer as in the benchmark function
    model_config = MODELS_TO_TEST[model_name]
    model = model_config["model_fn"](weights=model_config["weights"])

    # Set up the classifier based on model type
    # ... (same code as in run_finetuning_benchmark)

    # Initialize optimizer and scheduler.
    # Uses torch.optim directly so the function does not depend on the `optim`
    # alias that is imported in the benchmark cell further down the notebook.
    optimizer = torch.optim.Adam(
        model.parameters(), lr=FINETUNE_LEARNING_RATE, weight_decay=FINETUNE_WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=2)

    # Load checkpoint
    start_epoch, best_val_accuracy, early_stopping_counter = load_checkpoint(
        model, optimizer, scheduler, model_name
    )

    if start_epoch == 0:
        logger.info(f"No checkpoint found for {model_name}. Starting from scratch.")
    else:
        logger.info(f"Resuming training for {model_name} from epoch {start_epoch}")

    # Continue with training loop as in the benchmark function
    # ...

In [None]:
def save_benchmark_results(results, benchmark_type='finetuning', base_dir='./results'):
    """
    Saves benchmark results to JSON and CSV files with git commit info and parameters.

    Args:
        results: List of per-model result dicts (keys may differ between entries).
        benchmark_type: Prefix of the output filenames.
        base_dir: Directory the files are written to; created if missing.

    Returns:
        ``(json_path, csv_path)``. The CSV is only written when ``results`` is
        non-empty; its path is returned either way.
    """
    # Create results directory if it doesn't exist
    os.makedirs(base_dir, exist_ok=True)

    # Get git info
    git_info = get_git_info()

    # Generate filename with parameters instead of timestamp
    params_str = f"samples{NUM_TRAIN_SAMPLES_PER_CATEGORY}_epochs{NUM_FINETUNE_EPOCHS}"
    if git_info["available"]:
        # The commit-message fragment ends up in a filename: a path separator in
        # it would point into a directory that does not exist, so neutralise it.
        message_part = f"{git_info['message'][:10]}"
        for separator in (os.sep, os.altsep):
            if separator:  # os.altsep is None on POSIX
                message_part = message_part.replace(separator, "_")
        git_str = f"_{message_part}..."
    else:
        git_str = ""

    base_filename = f"{benchmark_type}{git_str}_{params_str}"

    # Prepare paths
    json_path = os.path.join(base_dir, f"{base_filename}.json")
    csv_path = os.path.join(base_dir, f"{base_filename}.csv")

    # Add metadata
    metadata = {
        "git_info": git_info,
        "device": str(DEVICE),
        "num_categories": len(QUICKDRAW_CATEGORIES),
        "samples_per_category": NUM_TRAIN_SAMPLES_PER_CATEGORY,
        "test_samples_per_category": NUM_TEST_SAMPLES_PER_CATEGORY,
        "finetune_epochs": NUM_FINETUNE_EPOCHS
    }

    # Save JSON with metadata.
    # NOTE(review): NaN values in results are written as the bare token NaN,
    # which Python reads back but strict JSON parsers reject.
    with open(json_path, 'w') as f:
        json.dump({"metadata": metadata, "results": results}, f, indent=2)

    # Save CSV; the header is the union of all keys because entries for failed
    # models can carry keys that successful ones do not.
    csv_written = False
    if results:
        fieldnames = sorted({key for result in results for key in result.keys()})
        with open(csv_path, 'w', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(results)
        csv_written = True

    if csv_written:
        logger.info(f"Benchmark results saved to {json_path} and {csv_path}")
    else:
        logger.info(f"Benchmark results saved to {json_path}; no CSV written because results are empty")
    return json_path, csv_path

# Benchmarks

### Benchmark Lightweight models for Fine-tuning

In [None]:
# --- Enhanced Fine-tuning Benchmarking Loop (Part 1) ---
import copy
import random
import torch.optim as optim
import torch.nn as nn


def _make_mlp_head(in_features, num_classes):
    """Build the Linear/BatchNorm/ReLU/Dropout classifier head used for the linear-headed models."""
    return nn.Sequential(
        nn.Linear(in_features, 1024),
        nn.BatchNorm1d(1024),
        nn.ReLU(inplace=True),
        nn.Dropout(0.4),
        nn.Linear(1024, 512),
        nn.BatchNorm1d(512),
        nn.ReLU(inplace=True),
        nn.Dropout(0.3),
        nn.Linear(512, num_classes)
    )


def _attach_finetune_head(model_name, model, num_classes):
    """
    Replace the classifier of ``model`` in place for the given benchmark model.

    Returns ``(feature_layers, lr_scale, weight_decay_scale)`` where
    ``feature_layers`` are the backbone modules used for gradual unfreezing and
    the two scales multiply the global fine-tuning learning rate / weight decay.
    Returns ``None`` when no head is defined for ``model_name``.
    """
    if model_name == "MobileNetV3-Small":
        features = model.features
        model.classifier = _make_mlp_head(model.classifier[0].in_features, num_classes)
        feature_layers = [features[i] for i in range(12)] + [model.avgpool]
        return feature_layers, 1.0, 1.0

    if model_name == "SqueezeNet1_1":
        in_channels = model.classifier[1].in_channels
        # Convolutional head: 1x1 conv -> ReLU -> global average pool -> 1x1 conv
        model.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Conv2d(in_channels, 512, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(512, num_classes, kernel_size=1)
        )
        return [model.features[i] for i in range(13)], 1.0, 1.0

    if model_name == "ShuffleNetV2_x0_5":
        model.fc = _make_mlp_head(model.fc.in_features, num_classes)
        feature_layers = [
            model.conv1,
            model.maxpool,
            model.stage2,
            model.stage3,
            model.stage4,
            model.conv5
        ]
        # ShuffleNetV2 is meant to train with a higher learning rate and a
        # slightly lower weight decay than the other models.
        return feature_layers, 1.5, 0.8

    return None


def _train_one_epoch(model, train_loader, criterion, optimizer, accumulation_steps):
    """Run one training epoch with gradient accumulation; return (mean loss, accuracy in %)."""
    model.train()
    running_loss = 0.0
    correct = 0
    seen = 0
    num_batches = len(train_loader)

    # Zero gradients at the beginning of epoch
    optimizer.zero_grad()

    for batch_idx, (images, labels) in enumerate(tqdm(train_loader, desc="Training", leave=False)):
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        outputs = model(images)
        # Normalize so the accumulated gradient matches one large batch
        loss = criterion(outputs, labels) / accumulation_steps
        loss.backward()

        # Step after every accumulation group and on the final (possibly partial) one
        if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == num_batches:
            optimizer.step()
            optimizer.zero_grad()

        # Scale the loss back for reporting
        running_loss += loss.item() * images.size(0) * accumulation_steps
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
        seen += labels.size(0)

    return running_loss / seen, 100 * correct / seen


def _evaluate_model(model, data_loader, criterion, desc):
    """Evaluate in eval mode without gradients; return (mean loss, accuracy in %, sample count)."""
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in tqdm(data_loader, desc=desc, leave=False):
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            outputs = model(images)
            loss = criterion(outputs, labels)

            total_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    return total_loss / total, 100 * correct / total, total


def run_finetuning_benchmark():
    """
    Fine-tune every model in ``MODELS_TO_TEST`` on the QuickDraw data and benchmark it.

    For each model: replace the classifier head, train with early stopping,
    LR reduction on plateau, gradient accumulation and periodic checkpoints,
    restore the best weights, save them as ``<model_name>_final.pth`` and
    measure loss, accuracy and inference speed on the test split. A failure in
    one model is logged and recorded as an error entry; the benchmark continues.

    Returns:
        List of per-model result dicts (also written via ``save_benchmark_results``).
    """
    finetuning_results = []

    # Get data with augmentation
    train_dataset, val_dataset, test_dataset = get_augmented_quickdraw_data(
        QUICKDRAW_CATEGORIES,
        NUM_TRAIN_SAMPLES_PER_CATEGORY,
        NUM_TEST_SAMPLES_PER_CATEGORY,
        BINARY_DATA_ROOT
    )

    # Calculate adaptive batch size based on dataset size
    effective_batch_size = get_adaptive_batch_size(
        NUM_TRAIN_SAMPLES_PER_CATEGORY, len(QUICKDRAW_CATEGORIES))
    logger.info(
        f"Using adaptive batch size of {effective_batch_size} for feature extraction")

    # Accumulate gradients when the batch is smaller than MAX_BATCH_SIZE.
    # The value does not change between models or epochs, so compute it once.
    gradient_accumulation_steps = max(1, MAX_BATCH_SIZE // effective_batch_size)
    if effective_batch_size < MAX_BATCH_SIZE:
        logger.info(
            f"Using gradient accumulation with {gradient_accumulation_steps} steps")

    # Create data loaders
    loader_kwargs = dict(batch_size=effective_batch_size, num_workers=4,
                         pin_memory=True, prefetch_factor=3)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    logger.info(f"\n--- Starting Enhanced Fine-tuning Benchmark ---")
    logger.info(f"Using device: {DEVICE}")
    logger.info(
        f"Fine-tuning on QuickDraw categories: {', '.join(QUICKDRAW_CATEGORIES)}")
    logger.info(
        f"Samples per category: {NUM_TRAIN_SAMPLES_PER_CATEGORY} train, {NUM_TEST_SAMPLES_PER_CATEGORY} test.")
    logger.info(
        f"Number of fine-tuning epochs: {NUM_FINETUNE_EPOCHS}, Learning Rate: {FINETUNE_LEARNING_RATE}")
    logger.info(
        f"Weight Decay: {FINETUNE_WEIGHT_DECAY}, Gradual Unfreezing: {USE_GRADUAL_UNFREEZING}")
    logger.info(
        f"Loading QuickDraw data from local directory: {os.path.abspath(BINARY_DATA_ROOT)}\n")

    num_classes = len(QUICKDRAW_CATEGORIES)

    for model_name, config in tqdm(MODELS_TO_TEST.items(), desc="Benchmarking Models (Enhanced Fine-tuning)", unit="model"):
        logger.info(f"--- Enhanced Fine-tuning Model: {model_name} ---")

        # Clear GPU memory before loading a new model
        clear_gpu_memory()

        model_params_ft = float('nan')  # reported as NaN if the model cannot even be built
        best_val_accuracy = 0.0

        try:
            model_to_finetune = config["model_fn"](weights=config["weights"])
            # Trainable parameters (in millions), counted before the head is replaced
            model_params_ft = sum(
                p.numel() for p in model_to_finetune.parameters() if p.requires_grad) / 1_000_000

            # Create custom classifier heads with improved architectures
            head_setup = _attach_finetune_head(model_name, model_to_finetune, num_classes)
            if head_setup is None:
                logger.warning(
                    f"Classifier modification not defined for {model_name}. Skipping fine-tuning.")
                continue
            feature_layers, lr_scale, weight_decay_scale = head_setup

            model_to_finetune.to(DEVICE)

            # Enable gradient checkpointing if configured
            if USE_GRADIENT_CHECKPOINTING and hasattr(model_to_finetune, 'features'):
                logger.info(f"Enabling gradient checkpointing for {model_name}")
                try:
                    # Check if the features module has gradient_checkpointing_enable
                    if hasattr(model_to_finetune.features, 'gradient_checkpointing_enable'):
                        model_to_finetune.features.gradient_checkpointing_enable()
                    # For models where features is a Sequential module
                    elif isinstance(model_to_finetune.features, nn.Sequential):
                        # Skip gradient checkpointing for Sequential modules
                        logger.info(f"Gradient checkpointing not available for {model_name} with Sequential features")
                    else:
                        logger.info(f"Gradient checkpointing not supported for this model architecture")
                except Exception as e:
                    logger.warning(f"Failed to enable gradient checkpointing for {model_name}: {e}")
                    # Continue without gradient checkpointing

            # --- Enhanced Fine-tuning Procedure ---
            logger.info(f"Starting fine-tuning for {model_name}...")
            start_time = time.time()

            # Initialize optimizer, criterion, and scheduler.
            # The per-model scales are applied here; previously a model-specific
            # optimizer was created earlier and then silently replaced by this one.
            learning_rate = FINETUNE_LEARNING_RATE * lr_scale
            weight_decay = FINETUNE_WEIGHT_DECAY * weight_decay_scale
            logger.info(
                f"Optimizer settings for {model_name}: lr={learning_rate}, weight_decay={weight_decay}")
            optimizer = optim.Adam(model_to_finetune.parameters(),
                                   lr=learning_rate, weight_decay=weight_decay)
            criterion = nn.CrossEntropyLoss()
            # No `verbose` argument: it is deprecated in recent PyTorch releases
            # (newer ones may reject it), and the learning rate is logged every
            # epoch below anyway.
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode='max', factor=0.5, patience=2)

            # Initialize early stopping variables
            early_stopping_patience = 3
            early_stopping_counter = 0

            # Check for existing checkpoints and resume if available
            start_epoch = 0
            if RESUME_FROM_CHECKPOINT:
                start_epoch, best_val_accuracy, early_stopping_counter = load_checkpoint(
                    model_to_finetune, optimizer, scheduler, model_name
                )

            # NOTE(review): after a resume these are the weights of the *latest*
            # checkpoint, which are not necessarily the ones that achieved
            # best_val_accuracy; the best weights are not reloaded here.
            best_model_wts = copy.deepcopy(model_to_finetune.state_dict())

            # Fine-tuning loop
            for epoch in range(start_epoch, NUM_FINETUNE_EPOCHS):
                logger.info(f"Epoch {epoch + 1}/{NUM_FINETUNE_EPOCHS}")

                # Gradual unfreezing: one more backbone layer per epoch, all from epoch 10.
                # NOTE(review): nothing in this function freezes the backbone first,
                # so unless the layers arrive frozen this only re-enables gradients
                # that are already enabled. Confirm the intended schedule.
                if USE_GRADUAL_UNFREEZING:
                    layers_to_unfreeze = feature_layers[:epoch + 1] if epoch < 10 else feature_layers
                    for layer in layers_to_unfreeze:
                        for p in layer.parameters():
                            p.requires_grad = True

                # Training phase
                epoch_train_loss, epoch_train_accuracy = _train_one_epoch(
                    model_to_finetune, train_loader, criterion, optimizer,
                    gradient_accumulation_steps)
                logger.info(
                    f"Train Loss: {epoch_train_loss:.4f}, Train Accuracy: {epoch_train_accuracy:.2f}%")

                # Validation phase
                epoch_val_loss, epoch_val_accuracy, _ = _evaluate_model(
                    model_to_finetune, val_loader, criterion, "Validation")
                logger.info(
                    f"Val Loss: {epoch_val_loss:.4f}, Val Accuracy: {epoch_val_accuracy:.2f}%")

                # Clear GPU memory after each epoch
                clear_gpu_memory()

                # Early stopping logic
                is_best = epoch_val_accuracy > best_val_accuracy
                if is_best:
                    best_val_accuracy = epoch_val_accuracy
                    best_model_wts = copy.deepcopy(
                        model_to_finetune.state_dict())
                    logger.info(
                        f"New best model found for {model_name}! (Val Accuracy: {best_val_accuracy:.2f}%)")
                    early_stopping_counter = 0
                else:
                    early_stopping_counter += 1
                    logger.info(
                        f"Validation accuracy didn't improve. Counter: {early_stopping_counter}/{early_stopping_patience}")

                # Save checkpoint at specified intervals or if it's the best model
                if (epoch + 1) % CHECKPOINT_INTERVAL == 0 or is_best:
                    save_checkpoint(
                        model_to_finetune,
                        optimizer,
                        scheduler,
                        epoch,
                        best_val_accuracy,
                        early_stopping_counter,
                        model_name,
                        is_best=is_best
                    )

                # Check if early stopping criteria is met
                if early_stopping_counter >= early_stopping_patience:
                    logger.info(
                        f"Early stopping triggered after {epoch+1} epochs")
                    break

                # Learning rate scheduler
                scheduler.step(epoch_val_accuracy)
                logger.info(
                    f"Current learning rate: {optimizer.param_groups[0]['lr']:.6f}")

            # Load the best model weights
            model_to_finetune.load_state_dict(best_model_wts)

            # Save the final model (the best weights restored above). The directory
            # may not exist yet if no checkpoint was written during training.
            os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
            final_model_path = f"{MODEL_SAVE_PATH}/{model_name}_final.pth"
            torch.save(model_to_finetune.state_dict(), final_model_path)
            logger.info(f"Final model saved: {final_model_path}")

            # Record total training time
            training_time = time.time() - start_time
            current_ft_train_time = f"{training_time:.2f}"
            logger.info(f"Total training time: {current_ft_train_time}s")

            # Evaluate on the test set and track inference time. _evaluate_model
            # switches to eval mode itself, which matters when the epoch loop did
            # not run (e.g. resuming a finished run) and the model is still in
            # train mode.
            start_inference_time = time.time()
            test_loss, test_accuracy, test_total = _evaluate_model(
                model_to_finetune, test_loader, criterion, "Testing")
            inference_time = time.time() - start_inference_time
            current_ft_inference_time = f"{inference_time:.4f}"
            current_ft_inference_speed = f"{test_total / inference_time:.2f}"

            logger.info(
                f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")

            # Append the results for this model to the overall results
            finetuning_results.append({
                'model_name': model_name,
                'test_loss': test_loss,
                'test_accuracy': test_accuracy,
                'params': model_params_ft,
                'train_time': current_ft_train_time,
                'inference_time': current_ft_inference_time,
                'inference_speed': current_ft_inference_speed,
                'best_val_accuracy': best_val_accuracy
            })
        except Exception as e:
            # Keep the benchmark going for the remaining models, but keep the
            # traceback in the log so the failure can be diagnosed.
            logger.error(f"Error occurred while fine-tuning {model_name}: {e}", exc_info=True)
            finetuning_results.append({
                'model_name': model_name,
                'test_loss': float('nan'),
                'test_accuracy': 0,
                'params': model_params_ft,
                'train_time': 'N/A',
                'inference_time': 'N/A',
                'inference_speed': 'N/A',
                'best_val_accuracy': 0,
                'error': str(e)
            })

    # --- Benchmarking Results ---
    logger.info(f"\n--- Enhanced Fine-tuning Benchmarking Results ---")
    for result in finetuning_results:
        if 'error' in result:
            logger.info(
                f"Model: {result['model_name']}, Error: {result['error']}")
        else:
            logger.info(f"Model: {result['model_name']}, Test Loss: {result['test_loss']:.4f}, Test Accuracy: {result['test_accuracy']:.2f}%, Params: {result['params']:.2f}M, Train Time: {result['train_time']}, Inference Time: {result['inference_time']}, Inference Speed: {result['inference_speed']}")

    logger.info(f"Best model(s) based on test accuracy:")
    best_models = sorted(finetuning_results, key=lambda x: x.get(
        'test_accuracy', 0), reverse=True)[:3]
    for bm in best_models:
        logger.info(f" - {bm['model_name']}: {bm['test_accuracy']:.2f}%")

    logger.info(f"Best model(s) based on validation accuracy:")
    best_val_models = sorted(finetuning_results, key=lambda x: x.get(
        'best_val_accuracy', 0), reverse=True)[:3]
    for bvm in best_val_models:
        logger.info(f" - {bvm['model_name']}: {bvm['best_val_accuracy']:.2f}%")

    logger.info(f"\n--- Enhanced Fine-tuning Benchmark Completed ---")

    save_benchmark_results(finetuning_results, benchmark_type='finetuning')
    return finetuning_results

In [None]:
# Run Benchmark for Fine-tuning.
# Keep the per-model result dicts in a variable so later cells can reuse them
# and the cell does not echo the whole list as its output (a summary is logged).
finetuning_results = run_finetuning_benchmark()

### Save Models with Classes

In [None]:

# Lookup table used at inference time: class index -> category name,
# following the order of QUICKDRAW_CATEGORIES.
IDX_TO_CLASS = dict(enumerate(QUICKDRAW_CATEGORIES))

# Inference wrapper that includes class names
class ClassNameInferenceWrapper:
    """Pair a classifier with an index -> class-name table for readable predictions.

    The wrapped model is switched to evaluation mode when the wrapper is built.
    """

    def __init__(self, model, idx_to_class=None):
        """Store the model and the lookup table.

        Args:
            model: Callable module producing per-class scores; softmax and max
                are taken over dimension 1 of its output.
            idx_to_class: Optional mapping from class index to class name. Any
                falsy value (None or an empty mapping) selects the module-level
                IDX_TO_CLASS instead.
        """
        self.model = model
        # Keep the category list on the instance so it travels with the wrapper
        self.QUICKDRAW_CATEGORIES = QUICKDRAW_CATEGORIES
        self.idx_to_class = idx_to_class if idx_to_class else IDX_TO_CLASS
        self.model.eval()  # inference only

    def predict(self, inputs):
        """Classify a batch.

        Args:
            inputs: Batch passed straight to the wrapped model.

        Returns:
            dict with 'class_idx' (numpy array of winning indices),
            'class_name' (list of names, "Unknown" for unmapped indices) and
            'probabilities' (numpy array of softmax scores).
        """
        with torch.no_grad():
            logits = self.model(inputs)
            softmax_scores = torch.nn.functional.softmax(logits, dim=1)
            winners = torch.max(logits, 1).indices

            # numpy is easier for callers to consume than device tensors
            winner_array = winners.cpu().numpy()
            score_array = softmax_scores.cpu().numpy()

            names = []
            for class_index in winner_array:
                names.append(self.idx_to_class.get(class_index, "Unknown"))

            return {
                'class_idx': winner_array,
                'class_name': names,
                'probabilities': score_array
            }

    def predict_single(self, input_tensor):
        """Classify one sample and return un-batched results.

        A 3-dimensional tensor gets a leading batch dimension; any other rank
        is forwarded unchanged. Only the first entry of each batched result is
        returned.
        """
        batch = input_tensor.unsqueeze(0) if input_tensor.dim() == 3 else input_tensor
        batched = self.predict(batch)
        return {
            field: batched[field][0]
            for field in ('class_idx', 'class_name', 'probabilities')
        }
    
def save_model_with_classes(model, model_path, class_names):
    """Save model weights along with class information for portability.

    The class information is written to a JSON sidecar next to the weights,
    named ``<model_path without extension>_classes.json``.

    Args:
        model: Object exposing ``state_dict()``.
        model_path: Destination file for the weights.
        class_names: JSON-serializable class information (e.g. list of names).

    Returns:
        Tuple ``(model_path, class_info_path)``.
    """
    # Derive the sidecar name from the path stem instead of str.replace('.pth', ...):
    # replace() leaves the path unchanged when there is no '.pth' (so the JSON
    # would overwrite the weights file) and also rewrites '.pth' occurring in
    # directory names.
    stem, _extension = os.path.splitext(model_path)
    class_info_path = f"{stem}_classes.json"

    # Save model weights
    torch.save(model.state_dict(), model_path)

    # Save class information
    with open(class_info_path, 'w') as f:
        json.dump(class_names, f)

    logger.info(f"Model saved to {model_path}")
    logger.info(f"Class information saved to {class_info_path}")

    return model_path, class_info_path

### Class-wise Evaluation Metrics

In [None]:
def evaluate_model_by_class(model, test_loader, class_names):
    """Evaluate model performance for each class separately.

    Args:
        model: Classifier; put into eval mode and called on batches moved to DEVICE.
        test_loader: Iterable of ``(images, labels)`` batches with integer labels
            that index into ``class_names``.
        class_names: Sequence of class names; position is the class index.

    Returns:
        dict mapping class name -> accuracy in percent. Classes with no samples
        in ``test_loader`` are omitted.

    Raises:
        IndexError: If a label falls outside ``[0, len(class_names))``.
    """
    model.eval()
    num_classes = len(class_names)
    # Counts live on the CPU so one small transfer per batch replaces the
    # per-sample .item() calls (each of which forces a device sync).
    class_correct = torch.zeros(num_classes, dtype=torch.long)
    class_total = torch.zeros(num_classes, dtype=torch.long)

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating by class"):
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)

            labels_cpu = labels.cpu().reshape(-1)
            predicted_cpu = predicted.cpu().reshape(-1)

            if labels_cpu.numel() > 0:
                lowest, highest = int(labels_cpu.min()), int(labels_cpu.max())
                if lowest < 0 or highest >= num_classes:
                    raise IndexError(
                        f"label out of range for {num_classes} classes: "
                        f"min={lowest}, max={highest}")

            # Histogram of all labels, and of the labels that were predicted correctly
            class_total += torch.bincount(labels_cpu, minlength=num_classes)
            class_correct += torch.bincount(
                labels_cpu[predicted_cpu == labels_cpu], minlength=num_classes)

    # logger.info and return results
    logger.info("\nClass-wise Accuracy:")
    class_accuracies = {}
    for name, n_correct, n_total in zip(
            class_names, class_correct.tolist(), class_total.tolist()):
        if n_total > 0:
            accuracy = 100 * n_correct / n_total
            logger.info(f"{name}: {accuracy:.2f}%")
            class_accuracies[name] = accuracy

    return class_accuracies

### Inference

In [None]:
# Load a test image for inference
# Let's get a sample from the test dataset
_, _, test_dataset = get_augmented_quickdraw_data(
    QUICKDRAW_CATEGORIES[:5],  # Using just a few categories for faster loading
    10,  # Small number of samples per category
    5,
    BINARY_DATA_ROOT
)

# Get a batch of test images
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=True)
images, labels = next(iter(test_loader))

# Use the class name wrapper for inference
# Example with a saved model:
model = models.mobilenet_v3_small(weights=None)
# Set up classifier for the number of classes we have.
# NOTE(review): this head must match the architecture used when the checkpoint
# below was saved, otherwise load_state_dict fails -- confirm against the wrapper.
in_features = model.classifier[0].in_features
model.classifier = nn.Sequential(
    nn.Linear(in_features, 1024),
    nn.BatchNorm1d(1024),
    nn.ReLU(inplace=True),
    nn.Dropout(0.4),
    nn.Linear(1024, 512),
    nn.BatchNorm1d(512),
    nn.ReLU(inplace=True),
    nn.Dropout(0.3),
    nn.Linear(512, len(QUICKDRAW_CATEGORIES))
)

# Try to load the model if it exists
checkpoint_path = f"{MODEL_SAVE_PATH}/MobileNetV3-Small_best.pth"
try:
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine
    # instead of failing and silently falling back to the untrained model.
    model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
    logger.info(f"Successfully loaded model from {checkpoint_path}")
except Exception as e:
    logger.info(f"Could not load model: {e}. Using untrained model for demonstration.")

model.to(DEVICE)

# Wrap the model
inference_wrapper = ClassNameInferenceWrapper(model)

# Move images to device
images = images.to(DEVICE)

# Make predictions with class names
result = inference_wrapper.predict(images)
logger.info(f"Predicted classes: {result['class_name']}")
logger.info(f"True labels: {[QUICKDRAW_CATEGORIES[label.item()] for label in labels]}")

In [None]:
def save_sample_images(images, labels, predictions, class_names, save_dir="./sample_images"):
    """Save sample images with their true and predicted labels.

    At most the first five images are written, one PNG each, named
    ``sample_<n>_<true label>_pred_<predicted label>.png``.

    Args:
        images: Indexable batch of image tensors.
        labels: Indexable batch of integer label tensors indexing ``class_names``.
        predictions: dict with a 'class_name' sequence aligned with ``images``.
        class_names: Sequence mapping label index -> class name.
        save_dir: Output directory; created (including parents) if missing.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs
    os.makedirs(save_dir, exist_ok=True)

    # Build the tensor -> PIL converter once rather than per image
    to_pil = T.ToPILImage()

    for i in range(min(5, len(images))):
        # Convert tensor to PIL image (on the CPU)
        img = to_pil(images[i].cpu())

        true_label = class_names[labels[i].item()]
        pred_label = predictions['class_name'][i]

        # Save image with informative filename
        filename = f"{save_dir}/sample_{i+1}_{true_label}_pred_{pred_label}.png"
        img.save(filename)
        logger.info(f"Saved sample image to {filename}")

### Trying an Ensemble of Models

In [None]:
import torch.nn.functional as F

class EnhancedEnsembleModel(nn.Module):
    """Combine several classifiers by learned weighted averaging or by stacking.

    * ``method='average'``: each base model's softmax output is scaled by a
      learned weight, the scaled outputs are summed, and the log of the
      mixture is returned.
    * ``method='stack'``: the base models' softmax outputs are concatenated and
      fed to a small meta-classifier sized for ``len(QUICKDRAW_CATEGORIES)``.

    The base models are held in a plain Python list, so they are NOT registered
    sub-modules: they do not appear in ``state_dict()`` or ``parameters()``.
    ``to``/``train``/``eval`` are overridden to propagate to them by hand.
    """

    def __init__(self, models, weights=None, device=DEVICE, method='average'):
        """
        Args:
            models: Non-empty sequence of base models producing per-class scores.
            weights: Optional initial values for the learned weights, one per
                model; defaults to equal values. They are passed through a
                softmax in the averaging forward pass, so they act as logits
                rather than as final mixture proportions.
            device: Device on which the learned weights are created.
            method: 'average' or 'stack'.

        Raises:
            ValueError: If ``models`` is empty.
        """
        super().__init__()
        if len(models) == 0:
            raise ValueError("EnhancedEnsembleModel requires at least one model")
        self.models = models
        # Initialize with equal weights if none provided
        self.weights = weights if weights is not None else [1.0/len(models)] * len(models)
        self.device = device
        self.method = method  # 'average' or 'stack'

        # Register weights as a parameter so they can be optimized. The explicit
        # floating dtype keeps integer-valued weight lists usable as a Parameter.
        self.learned_weights = nn.Parameter(
            torch.tensor(self.weights, dtype=torch.get_default_dtype(), device=device))

        # For stacking method, add a meta-classifier
        if self.method == 'stack':
            # Input size: number of classes * number of models
            # Each model produces probabilities for each class
            input_size = len(QUICKDRAW_CATEGORIES) * len(models)
            self.meta_classifier = nn.Sequential(
                nn.Linear(input_size, 512),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(512, len(QUICKDRAW_CATEGORIES))
            )

    def to(self, device):
        """Properly move all components to the specified device"""
        self.device = device
        # Base models are not registered sub-modules, so move them explicitly
        for model in self.models:
            model.to(device)
        if hasattr(self, 'meta_classifier'):
            self.meta_classifier.to(device)
        return super().to(device)  # This moves the learned_weights parameter

    def eval(self):
        """Put the ensemble and every base model in evaluation mode; returns self."""
        for model in self.models:
            model.eval()
        # nn.Module.eval() returns the module; the original override dropped it,
        # so `m = ensemble.eval()` yielded None.
        return super().eval()

    def train(self, mode=True):
        """Set training mode on the ensemble and every base model; returns self."""
        for model in self.models:
            model.train(mode)
        return super().train(mode)

    def forward(self, x):
        """Dispatch to the forward pass selected by ``self.method``.

        Raises:
            ValueError: If ``self.method`` is neither 'average' nor 'stack'.
        """
        if self.method == 'average':
            return self._forward_average(x)
        elif self.method == 'stack':
            return self._forward_stack(x)
        else:
            raise ValueError(f"Unknown ensemble method: {self.method}")

    def _forward_average(self, x):
        """Weighted averaging of model outputs"""
        # Apply softmax to learned weights
        normalized_weights = F.softmax(self.learned_weights, dim=0)

        # Apply torch.set_grad_enabled based on training mode
        with torch.set_grad_enabled(self.training):
            outputs = []
            for i, model in enumerate(self.models):
                output = model(x)
                # Apply softmax to get probabilities
                probs = F.softmax(output, dim=1)
                # Apply weight for this model
                outputs.append(probs * normalized_weights[i])

            # Sum the weighted outputs
            combined = torch.stack(outputs).sum(dim=0)
            # Return log-probabilities so the result can be fed to CrossEntropyLoss;
            # the epsilon guards against log(0)
            return torch.log(combined + 1e-8)

    def _forward_stack(self, x):
        """Stacking method - use a meta-classifier on concatenated model outputs"""
        all_probs = []

        # Get predictions from all models
        for model in self.models:
            with torch.set_grad_enabled(self.training):
                output = model(x)
                probs = F.softmax(output, dim=1)
                all_probs.append(probs)

        # Concatenate all probabilities into a single feature vector
        combined = torch.cat(all_probs, dim=1)

        # Ensure combined tensor is on the same device as the meta-classifier
        combined = combined.to(self.device)

        # Feed through meta-classifier
        return self.meta_classifier(combined)

def export_ensemble_for_deployment(ensemble, model_path):
    """Create a portable ensemble model package with all required components.

    The package bundles the ensemble's own state dict, each base model's state
    dict, the learned weights, the class names, git information and the
    training parameters. The output filename is ``model_path`` with a git
    message snippet (when available) and the training parameters inserted
    before the extension.

    Args:
        ensemble: Ensemble exposing ``state_dict()``, ``models`` and
            ``learned_weights``.
        model_path: Base output path.

    Returns:
        The path the package was actually written to.
    """
    # Get git info
    git_info = get_git_info()

    # Create a dictionary containing all necessary information
    export_data = {
        'state_dict': ensemble.state_dict(),
        'model_weights': [m.state_dict() for m in ensemble.models],
        'learned_weights': ensemble.learned_weights.detach().cpu().numpy().tolist(),
        'class_names': QUICKDRAW_CATEGORIES,
        'git_info': git_info,
        'samples_per_category': NUM_TRAIN_SAMPLES_PER_CATEGORY,
        'finetune_epochs': NUM_FINETUNE_EPOCHS
    }

    # Add git and parameter info to filename
    params_str = f"samples{NUM_TRAIN_SAMPLES_PER_CATEGORY}_epochs{NUM_FINETUNE_EPOCHS}"
    if git_info["available"]:
        # A commit message is free text: '/', ':', spaces or newlines would
        # produce an invalid or misplaced path, so anything that is not
        # alphanumeric, '-' or '_' is replaced with '_'.
        snippet = "".join(
            ch if (ch.isalnum() or ch in "-_") else "_"
            for ch in git_info['message'][:10])
        git_str = f"_{snippet}..."
    else:
        git_str = ""

    # Insert the suffix before the extension. splitext (unlike
    # str.replace('.pth', ...)) only touches the final extension and still
    # tags paths that do not end in '.pth'.
    stem, extension = os.path.splitext(model_path)
    model_path = f"{stem}{git_str}_{params_str}{extension}"

    # Save to file
    torch.save(export_data, model_path)
    logger.info(f"Portable ensemble saved to {model_path}")

    return model_path

def load_ensemble_for_inference(model_path):
    """Load a portable ensemble model for inference.

    Rebuilds the base architectures (MobileNetV3, SqueezeNet and, when a third
    set of weights is present, ShuffleNetV2), restores their weights and the
    ensemble's own state, then moves everything to DEVICE in evaluation mode.

    Args:
        model_path: File written by ``export_ensemble_for_deployment``.

    Returns:
        Tuple ``(ensemble, class_names)``.

    Raises:
        ValueError: If the number of saved base-model weight sets does not
            match the number of architectures that can be rebuilt.
    """
    # Load the export data
    export_data = torch.load(model_path, map_location=DEVICE)
    class_names = export_data['class_names']
    num_classes = len(class_names)
    saved_model_weights = export_data['model_weights']
    ensemble_state = export_data['state_dict']

    # Recreate the model architectures
    models = []
    models.append(MobileNetV3Wrapper(num_classes=num_classes).model)
    models.append(SqueezeNetWrapper(num_classes=num_classes).model)
    if len(saved_model_weights) > 2:
        models.append(ShuffleNetV2Wrapper(num_classes=num_classes).model)

    # zip() below would silently drop the surplus side, leaving an untrained
    # model in the ensemble or ignoring saved weights -- fail loudly instead.
    if len(saved_model_weights) != len(models):
        raise ValueError(
            f"Export contains {len(saved_model_weights)} base-model weight sets "
            f"but {len(models)} architectures were rebuilt")

    # Load individual model weights
    for model, weights in zip(models, saved_model_weights):
        model.load_state_dict(weights)

    # The export does not record the ensemble method; a stacking ensemble is
    # recognizable by its meta-classifier entries in the state dict. Rebuilding
    # it as 'average' would make the strict load_state_dict below fail.
    # NOTE: the meta-classifier is sized from QUICKDRAW_CATEGORIES, which must
    # match the class list used at export time.
    is_stacked = any(key.startswith('meta_classifier.') for key in ensemble_state)
    method = 'stack' if is_stacked else 'average'

    # Create ensemble
    ensemble = EnhancedEnsembleModel(
        models, weights=export_data['learned_weights'], method=method)
    ensemble.load_state_dict(ensemble_state)

    # Freshly built base models live on the CPU while learned_weights is created
    # on DEVICE; move everything together and switch to inference mode.
    ensemble.to(DEVICE)
    ensemble.eval()

    return ensemble, class_names

### Training the Ensemble Averaging Weights

In [None]:
def train_ensemble_weights(ensemble, epochs=5):
    """Train the ensemble weights using a small validation set.

    Only ``ensemble.learned_weights`` is optimized. The base models' parameters
    are frozen (``requires_grad = False``; this is left in place afterwards)
    and the base models are kept in evaluation mode while training.

    Args:
        ensemble: Ensemble exposing ``models`` and ``learned_weights``.
        epochs: Number of passes over the validation loader.

    Returns:
        The same ensemble, switched back to evaluation mode.
    """
    # Load a small validation set
    _, val_dataset, _ = get_augmented_quickdraw_data(
        QUICKDRAW_CATEGORIES,
        NUM_TRAIN_SAMPLES_PER_CATEGORY // 10,  # Use a smaller subset
        NUM_TEST_SAMPLES_PER_CATEGORY // 5,
        BINARY_DATA_ROOT
    )

    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=True, num_workers=2)

    # Only train the ensemble weights, not the individual models
    for model in ensemble.models:
        for param in model.parameters():
            param.requires_grad = False

    # Set ensemble to training mode so its forward pass tracks gradients
    ensemble.train()
    # ...but keep the frozen base models in eval mode: requires_grad=False does
    # not stop BatchNorm from updating its running statistics or Dropout from
    # firing in train mode, which would silently alter the fine-tuned models.
    for model in ensemble.models:
        model.eval()

    # Use optimizer only for the learned weights
    optimizer = optim.Adam([ensemble.learned_weights], lr=0.01)
    criterion = nn.CrossEntropyLoss()

    logger.info(f"Training ensemble weights for {epochs} epochs...")
    for epoch in range(epochs):
        running_loss = 0.0
        correct = 0
        total = 0

        for images, labels in tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            optimizer.zero_grad()
            outputs = ensemble(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

        # An empty loader would otherwise divide by zero below
        if total == 0:
            logger.warning("Validation loader produced no samples; skipping ensemble weight training.")
            break

        # logger.info epoch statistics
        epoch_loss = running_loss / total
        epoch_acc = 100 * correct / total
        logger.info(f"Epoch {epoch+1}/{epochs} - Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.2f}%")

        # logger.info current weights
        normalized_weights = F.softmax(ensemble.learned_weights, dim=0).cpu().detach().numpy()
        weight_str = ", ".join([f"{w:.4f}" for w in normalized_weights])
        logger.info(f"Current weights: [{weight_str}]")

    # Switch back to evaluation mode
    ensemble.eval()
    return ensemble

### Ensemble Evaluation

In [None]:
def load_model_wrapper(wrapper_class, model_name):
    """Helper to load a model with proper error handling.

    Args:
        wrapper_class: Zero-argument callable returning an object with a
            ``model`` attribute.
        model_name: Checkpoint stem; weights are read from
            ``{MODEL_SAVE_PATH}/{model_name}_best.pth``.

    Returns:
        The wrapper with weights loaded and its model in evaluation mode, or
        None if construction or loading failed (the failure is logged).
    """
    try:
        model_path = f"{MODEL_SAVE_PATH}/{model_name}_best.pth"
        wrapper = wrapper_class()
        # map_location lets checkpoints saved on a GPU load on CPU-only machines
        wrapper.model.load_state_dict(torch.load(model_path, map_location=DEVICE))
        wrapper.model.eval()  # Set to evaluation mode
        logger.info(f"Successfully loaded {model_name} from {model_path}")
        return wrapper
    except Exception as e:
        # Best-effort by design: callers decide what to do with a missing model
        logger.warning(f"Failed to load {model_name}: {e}")
        return None


def evaluate_ensemble(method='average', test_loader=None):
    """Build, train, evaluate and save an ensemble of the fine-tuned models.

    Steps: load the saved MobileNetV3 / SqueezeNet (required) and ShuffleNetV2
    (optional) checkpoints, build an EnhancedEnsembleModel, train either the
    meta-classifier ('stack') or the averaging weights ('average'), compare the
    ensemble against each base model on the test set, export the ensemble and
    its class-wise accuracies, and save the comparison via
    save_benchmark_results().

    Args:
        method: 'average' or 'stack'.
        test_loader: Optional loader of ``(images, labels)`` batches; when None
            a test set is loaded with get_augmented_quickdraw_data().

    Returns:
        The trained ensemble, or None if the required base models could not be
        loaded. If the test loader yields no samples the ensemble is returned
        without computing or saving any metrics.
    """
    logger.info(f"Evaluating ensemble using {method} method...")

    # Load the models using the wrappers
    try:
        # Use the helper function to load models
        mobilenet_wrapper = load_model_wrapper(
            MobileNetV3Wrapper, "MobileNetV3-Small")
        squeezenet_wrapper = load_model_wrapper(
            SqueezeNetWrapper, "SqueezeNet1_1")
        shufflenet_wrapper = load_model_wrapper(
            ShuffleNetV2Wrapper, "ShuffleNetV2_x0_5")

        models = []
        initial_weights = []

        # Add models that were successfully loaded
        if mobilenet_wrapper and squeezenet_wrapper:
            models.extend([mobilenet_wrapper.model, squeezenet_wrapper.model])
            initial_weights.extend([0.4, 0.3])

            # Try to add ShuffleNet if available
            if shufflenet_wrapper:
                models.append(shufflenet_wrapper.model)
                initial_weights.append(0.3)
                logger.info(
                    f"Using 3-model ensemble ({method} method) with MobileNet, SqueezeNet, ShuffleNet")
            else:
                # Rebalance weights for 2-model ensemble
                initial_weights = [0.6, 0.4]
                logger.info("Using 2-model ensemble (MobileNet, SqueezeNet)")
        else:
            raise ValueError("Could not load enough models for ensemble")

        # Create enhanced ensemble with the available models
        ensemble = EnhancedEnsembleModel(
            models, weights=initial_weights, method=method)

    except Exception as e:
        logger.info(f"Could not load models: {e}. Cannot create ensemble.")
        return None

    # Move ensemble to device
    ensemble.to(DEVICE)
    ensemble.eval()

    # Load test data if not provided
    if test_loader is None:
        _, _, test_dataset = get_augmented_quickdraw_data(
            QUICKDRAW_CATEGORIES,
            NUM_TRAIN_SAMPLES_PER_CATEGORY,
            NUM_TEST_SAMPLES_PER_CATEGORY,
            BINARY_DATA_ROOT
        )

        test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4, pin_memory=True,
                                 prefetch_factor=3)

    # If using stacking method, train the meta-classifier first
    if method == 'stack':
        logger.info("Training meta-classifier for stacking ensemble...")
        # Create a small dataset for training the meta-classifier
        _, val_dataset, _ = get_augmented_quickdraw_data(
            QUICKDRAW_CATEGORIES,
            NUM_TRAIN_SAMPLES_PER_CATEGORY // 5,  # Use a smaller subset
            NUM_TEST_SAMPLES_PER_CATEGORY // 5,
            BINARY_DATA_ROOT
        )

        val_loader = DataLoader(
            val_dataset, batch_size=64, shuffle=True, num_workers=2,
            prefetch_factor=3)

        # Freeze base models
        for model in ensemble.models:
            for param in model.parameters():
                param.requires_grad = False

        # Only train meta-classifier
        optimizer = optim.Adam(ensemble.meta_classifier.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss()

        # Train for a few epochs
        ensemble.train()
        # Keep the frozen base models in eval mode: requires_grad=False does not
        # stop BatchNorm running-stat updates or Dropout in train mode, which
        # would alter the fine-tuned models and feed noisy inputs to the
        # meta-classifier. The meta-classifier itself stays in train mode.
        for model in ensemble.models:
            model.eval()

        for epoch in range(5):
            running_loss = 0.0
            correct = 0
            total = 0

            for images, labels in tqdm(val_loader, desc=f"Meta-classifier epoch {epoch+1}/5"):
                images, labels = images.to(DEVICE), labels.to(DEVICE)

                optimizer.zero_grad()
                outputs = ensemble(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item() * images.size(0)
                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)

            # An empty loader would otherwise divide by zero below
            if total == 0:
                logger.warning("Validation loader produced no samples; skipping meta-classifier training.")
                break

            epoch_loss = running_loss / total
            epoch_acc = 100 * correct / total
            logger.info(f"Meta-classifier Epoch {epoch+1}/5 - Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.2f}%")

        ensemble.eval()
        clear_gpu_memory()  # Add explicit memory cleanup after training

    # For the averaging method, train the weights
    elif method == 'average':
        logger.info("Training ensemble weights for averaging method...")
        ensemble = train_ensemble_weights(ensemble, epochs=3)
        clear_gpu_memory()  # Add explicit memory cleanup after training

    # Evaluate
    model_correct = {
        "mobilenet": 0,
        "squeezenet": 0,
        "shufflenet": 0 if len(models) > 2 else None,
        "ensemble": 0
    }
    total = 0

    logger.info("Evaluating ensemble vs individual models...")
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating models"):
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            # Get predictions from individual models
            mobilenet_outputs = mobilenet_wrapper.model(images)
            squeezenet_outputs = squeezenet_wrapper.model(images)

            # Get ensemble prediction
            ensemble_outputs = ensemble(images)

            # Calculate accuracy for each model
            _, mobilenet_preds = torch.max(mobilenet_outputs, 1)
            _, squeezenet_preds = torch.max(squeezenet_outputs, 1)
            _, ensemble_preds = torch.max(ensemble_outputs, 1)

            # Update correct counts
            model_correct["mobilenet"] += (mobilenet_preds ==
                                           labels).sum().item()
            model_correct["squeezenet"] += (squeezenet_preds ==
                                            labels).sum().item()

            # Only evaluate ShuffleNet if it's part of the ensemble
            if len(models) > 2 and shufflenet_wrapper:
                shufflenet_outputs = shufflenet_wrapper.model(images)
                _, shufflenet_preds = torch.max(shufflenet_outputs, 1)
                model_correct["shufflenet"] += (shufflenet_preds ==
                                                labels).sum().item()

            model_correct["ensemble"] += (ensemble_preds ==
                                          labels).sum().item()

            total += labels.size(0)

    # Every accuracy below divides by `total`; bail out with a clear message
    # instead of a ZeroDivisionError when the loader was empty.
    if total == 0:
        logger.error("Test loader produced no samples; cannot compute accuracies.")
        return ensemble

    # Calculate and logger.info accuracies
    logger.info(f"\n--- Model Accuracy Comparison ({method} method) ---")
    mobilenet_acc = 100 * model_correct["mobilenet"] / total
    logger.info(f"MobileNetV3:  {mobilenet_acc:.2f}%")

    squeezenet_acc = 100 * model_correct["squeezenet"] / total
    logger.info(f"SqueezeNet:   {squeezenet_acc:.2f}%")

    # Only logger.info ShuffleNet accuracy if it's part of the ensemble
    if model_correct["shufflenet"] is not None:
        shufflenet_acc = 100 * model_correct["shufflenet"] / total
        logger.info(f"ShuffleNetV2: {shufflenet_acc:.2f}%")
        best_single = max(mobilenet_acc, squeezenet_acc, shufflenet_acc)
    else:
        best_single = max(mobilenet_acc, squeezenet_acc)

    ensemble_acc = 100 * model_correct["ensemble"] / total
    logger.info(f"Ensemble ({method}): {ensemble_acc:.2f}%")

    # Calculate improvement (negative when the ensemble is worse than the best single model)
    improvement = ensemble_acc - best_single
    logger.info(
        f"\nEnsemble improves accuracy by {improvement:.2f}% over the best single model")

    # For averaging method, logger.info the learned weights
    if method == 'average':
        normalized_weights = F.softmax(
            ensemble.learned_weights, dim=0).cpu().detach().numpy()
        logger.info("\nLearned model weights in ensemble:")
        logger.info(f"MobileNetV3:  {normalized_weights[0]:.4f}")
        logger.info(f"SqueezeNet:   {normalized_weights[1]:.4f}")
        if len(models) > 2:
            logger.info(f"ShuffleNetV2: {normalized_weights[2]:.4f}")

    # Save the ensemble model with error handling
    try:
        ensemble_path = f"{MODEL_SAVE_PATH}/ensemble_model_{method}.pth"

        # Use the export function for a fully portable model with error handling
        try:
            export_path = export_ensemble_for_deployment(ensemble, ensemble_path)
            logger.info(f"Ensemble model exported to {export_path}")
        except Exception as e:
            logger.error(f"Failed to export ensemble: {e}")
            # Try a simpler export approach
            torch.save(ensemble.state_dict(), ensemble_path)
            logger.info(f"Saved ensemble state dict to {ensemble_path} (fallback method)")

        # Also evaluate class-wise performance
        logger.info("\nEvaluating class-wise performance...")
        class_accuracies = evaluate_model_by_class(
            ensemble, test_loader, QUICKDRAW_CATEGORIES)

        # Save class accuracies with compression
        class_acc_path = f"{MODEL_SAVE_PATH}/ensemble_{method}_class_accuracies.json.gz"
        with gzip.open(class_acc_path, 'wt') as f:
            json.dump(class_accuracies, f, indent=2)
        logger.info(f"Class accuracies saved to {class_acc_path}")
    except Exception as e:
        logger.info(f"Could not save ensemble model: {e}")

    # Format results for saving
    ensemble_results = [
        {
            'model_name': 'MobileNetV3',
            'accuracy': mobilenet_acc,
            'ensemble_method': method
        },
        {
            'model_name': 'SqueezeNet',
            'accuracy': squeezenet_acc,
            'ensemble_method': method
        },
        {
            'model_name': f'Ensemble ({method})',
            'accuracy': ensemble_acc,
            'improvement_over_best': improvement,
            'ensemble_method': method
        }
    ]

    # Add ShuffleNet if it was used
    if model_correct["shufflenet"] is not None:
        ensemble_results.insert(2, {
            'model_name': 'ShuffleNetV2',
            'accuracy': shufflenet_acc,
            'ensemble_method': method
        })

    # Save results
    save_benchmark_results(ensemble_results, benchmark_type=f'ensemble_{method}')
    # Return ensemble for later use
    return ensemble

In [None]:
# Averaging ensemble: log a banner, then build/train/evaluate it
for banner_line in ("="*50, "EVALUATING ENSEMBLE WITH AVERAGING METHOD", "="*50):
    logger.info(banner_line)
averaging_ensemble = evaluate_ensemble(method='average')

# Stacking ensemble: same procedure with the meta-classifier variant
for banner_line in ("="*50, "EVALUATING ENSEMBLE WITH STACKING METHOD", "="*50):
    logger.info(banner_line)
stacking_ensemble = evaluate_ensemble(method='stack')

# Closing notes on how to read the two sets of results
for banner_line in (
    "="*50,
    "ENSEMBLE METHOD COMPARISON",
    "="*50,
    "If both methods ran successfully, you can compare their performance.",
    "The stacking method usually performs better when there are complementary strengths",
    "in the base models, while averaging is more robust to overfitting.",
):
    logger.info(banner_line)

# Small dataset used only to produce a handful of illustrated predictions
_, _, test_dataset = get_augmented_quickdraw_data(
    QUICKDRAW_CATEGORIES[:10],  # first 10 categories keep loading fast
    10,  # samples per category
    5,
    BINARY_DATA_ROOT
)

# One shuffled batch of five images is all that is needed
test_loader = DataLoader(test_dataset, batch_size=5, shuffle=True)

# Any failure here is logged rather than raised: the sample export is optional
try:
    images, labels = next(iter(test_loader))
    images = images.to(DEVICE)

    # Ensemble used for the samples; swap in stacking_ensemble if it performs better
    best_ensemble = averaging_ensemble
    if best_ensemble is not None:
        inference_wrapper = ClassNameInferenceWrapper(best_ensemble, IDX_TO_CLASS)

        # Predictions carry class names alongside indices and probabilities
        predictions = inference_wrapper.predict(images)

        # Write the images with true/predicted labels encoded in the filename
        save_sample_images(
            images=images,
            labels=labels,
            predictions=predictions,
            class_names=QUICKDRAW_CATEGORIES,
            save_dir="./sample_ensemble_predictions",
        )
except Exception as e:
    logger.info(f"Error saving sample images: {e}")