In [1]:
# Bootstrap conda inside the Colab runtime. condacolab.install() restarts the
# kernel when it finishes, so run this cell on its own before the rest of the notebook.
!pip install -q condacolab
import condacolab
condacolab.install() # This will restart the kernel

⏬ Downloading https://github.com/jaimergp/miniforge/releases/download/24.11.2-1_colab/Miniforge3-colab-24.11.2-1_colab-Linux-x86_64.sh...
📦 Installing...
📌 Adjusting configuration...
🩹 Patching environment...
⏲ Done in 0:00:11
🔁 Restarting kernel...


In [2]:
# Create a separate conda environment named "myenv" with Python 3.9 (-y answers the prompt).
!conda create -n myenv python=3.9 -y
# condacolab usually activates the base environment by default.
# To activate your specific environment for subsequent ! commands in the session:
# This is tricky because each ! is a new shell.
# condacolab aims to make the *kernel itself* run within a Conda environment.
# NOTE(review): in the output recorded below, the create step ended with
# "CondaError: KeyboardInterrupt", only `base` is listed, and Python reports 3.11.11,
# so "myenv" was presumably never created in that run — confirm before relying on it.
# Check active environment:
!conda env list
!python --version # Should show the python from the active conda env

Channels:
 - conda-forge
Platform: linux-64
Collecting package metadata (repodata.json): - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | failed

CondaError: KeyboardInterrupt


# conda environments:
#
base                   /usr/local

Python 3.11.11


In [3]:
# Print the version of the CUDA compiler (nvcc) available in this runtime.
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Jun__6_02:18:23_PDT_2024
Cuda compilation tools, release 12.5, V12.5.82
Build cuda_12.5.r12.5/compiler.34385749_0


In [None]:
# Install torch, torchvision and torchaudio from the PyTorch wheel index for CUDA 12.1 (cu121).
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

In [None]:
# Install AI Edge Torch; the conversion code below imports it as `ai_edge_torch`.
!pip install ai-edge-torch

In [None]:
import os
import requests
import urllib.parse

# Part 1: Download QuickDraw Dataset Subsample

# --- Configuration for QuickDraw Download ---
# Object categories to download, intended to be relevant to Indian primary
# education (Standards 1st to 5th). Keep the words common and visually distinct.
# The stated target is at most 50 distinct categories.
# NOTE(review): the list below currently holds 51 entries, one more than that target.
QUICKDRAW_CATEGORIES = [
    'apple', 'cat', 'dog', 'door', 'elephant', 'fish', 'flower', 'grapes',
    'grass', 'house', 'ice cream', 'jail', 'key', 'lion', 'moon', 'nose',
    'pencil', 'rabbit', 'sun', 'tree', 'umbrella', 'van', 'cake', 'airplane',
    'ant', 'banana', 'bed', 'bee', 'bicycle', 'bird', 'book', 'bread', 'bus',
    'elbow', 'ear', 'camera', 'car', 'chair', 'clock', 'cloud', 'hand',
    'computer', 'cookie', 'cow', 'crayon', 'cup', 'eraser', 'carrot', 'drums',
    'eye', 'knife'
]
# Each category is fetched as "<base url><url-quoted category>.npy".
QUICKDRAW_BASE_URL = "https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/"
# Local folder the .npy files are written to (spaces in names become underscores).
DOWNLOAD_DIR = "./quickdraw_data/"

def download_quickdraw_subset(categories, base_url, download_dir, timeout=30):
    """
    Downloads .npy files for specified QuickDraw categories.

    Each category is fetched from ``<base_url><url-quoted category>.npy`` and
    stored as ``<download_dir>/<category with spaces replaced by '_'>.npy``.
    A category whose target file already exists is skipped. Failures are
    reported with a printed message and the loop moves on to the next category.

    Args:
        categories: Iterable of category names (e.g. ``'ice cream'``).
        base_url: URL prefix, including the trailing slash.
        download_dir: Directory to write into; created if missing.
        timeout: Seconds to wait for the server before giving up on a request,
            so a stalled connection cannot hang the notebook indefinitely.

    Returns:
        None.
    """
    if not os.path.isdir(download_dir):
        os.makedirs(download_dir, exist_ok=True)
        print(f"Created directory: {download_dir}")

    print(f"Starting download of {len(categories)} QuickDraw categories...")
    for category in categories:
        # Sanitize category name for URL (replace spaces with %20)
        sanitized_category_name = urllib.parse.quote(category)
        url = f"{base_url}{sanitized_category_name}.npy"
        # Use underscore for local filename
        output_path = os.path.join(download_dir, f"{category.replace(' ', '_')}.npy")
        # Data is streamed into a side file and only renamed once complete, so an
        # interrupted transfer never leaves a truncated file that a later run
        # would mistake for a finished download.
        partial_path = f"{output_path}.part"

        if os.path.exists(output_path):
            print(f"File for '{category}' already exists at {output_path}. Skipping.")
            continue

        print(f"Downloading '{category}' from {url}...")
        try:
            # The context manager releases the connection even when streaming fails.
            with requests.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()  # Raise an exception for HTTP errors (4xx or 5xx)
                with open(partial_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
            os.replace(partial_path, output_path)
            print(f"Successfully downloaded and saved to {output_path}")

        except requests.exceptions.HTTPError as e:
            # e.response may be None, so read the status defensively.
            status_code = getattr(e.response, "status_code", None)
            if status_code == 404:
                print(f"Error: Category '{category}' not found (404) at {url}. Skipping.")
            else:
                print(f"Error downloading '{category}': {e}. Skipping.")
        except requests.exceptions.RequestException as e:
            print(f"Error downloading '{category}': {e}. Skipping.")
        except Exception as e:
            print(f"An unexpected error occurred for '{category}': {e}. Skipping.")
        finally:
            # After a successful rename the side file is gone; after a failure
            # this removes the incomplete data.
            if os.path.exists(partial_path):
                os.remove(partial_path)
    print("QuickDraw download process finished.")

# Part 2: Proof-of-Concept (POC) SeResNext to .tflite Conversion using AI Edge Torch

def convert_seresnext_to_tflite():
    """
    Loads a pretrained SeResNext model, converts it to .tflite using AI Edge Torch,
    and verifies the exported file.

    Steps:
        1. Import the required libraries; if any is missing, print install
           hints and return.
        2. Load ``seresnext50_32x4d`` from timm in eval mode.
        3. Convert with ``ai_edge_torch.convert`` using a random (1, 3, 224, 224)
           example input and export to ``./content/seresnext_model.tflite``.
        4. Run the exported model through the TFLite interpreter and compare
           its output with the PyTorch output on the same input.

    Every stage reports failures with a printed message instead of raising.

    Returns:
        None.
    """
    try:
        import torch
        import timm
        import ai_edge_torch
        import numpy as np # For dummy input and TFLite verification
        import tensorflow as tf # For TFLite verification
    except ImportError as e:
        print(f"ImportError: {e}. Please ensure all required libraries are installed.")
        print("Installation commands:")
        print("  pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu")
        print("  pip install timm")
        print("  pip install ai-edge-torch")
        print("  pip install tensorflow numpy")
        return

    print("\nStarting SeResNext to .tflite conversion...")

    # 1. Load Pretrained SeResNext Model
    # We'll use 'seresnext50_32x4d' as a common variant.
    # Its default input size is (3, 224, 224).
    model_name = 'seresnext50_32x4d'
    input_height, input_width = 224, 224
    tflite_model_path = "./content/seresnext_model.tflite"

    try:
        print(f"Loading pretrained model: {model_name}")
        model = timm.create_model(model_name, pretrained=True)
        model.eval()  # Set the model to evaluation mode
        print(f"Model {model_name} loaded successfully.")
    except Exception as e:
        print(f"Error loading model with timm: {e}")
        return

    # 2. Prepare Dummy Input
    # Input shape (batch_size, channels, height, width)
    dummy_input = torch.randn(1, 3, input_height, input_width)
    print(f"Prepared dummy input with shape: {dummy_input.shape}")

    # 3. Convert to .tflite with AI Edge Torch
    print("Converting model to AI Edge Torch format...")
    try:
        # The second argument to convert must be a tuple of example inputs.
        edge_model = ai_edge_torch.convert(model, (dummy_input,))
        print("Model converted to Edge format.")

        # Export to TFLite flatbuffer. The target folder is created first because
        # it may not exist yet (e.g. in a fresh working directory), in which case
        # writing the file would fail.
        os.makedirs(os.path.dirname(tflite_model_path), exist_ok=True)
        edge_model.export(tflite_model_path)
        print("Model exported to file.")

    except Exception as e:
        print(f"Error during AI Edge Torch conversion: {e}")
        return

    # 4. Verification (Optional but Recommended)
    print("\nVerifying the .tflite model...")
    try:
        interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
        interpreter.allocate_tensors()
        print("TFLite model loaded into interpreter.")

        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()

        print(f"  Input details: {input_details[0]['name']}, Shape: {input_details[0]['shape']}, Dtype: {input_details[0]['dtype']}")
        print(f"  Output details: {output_details[0]['name']}, Shape: {output_details[0]['shape']}, Dtype: {output_details[0]['dtype']}")

        # Prepare input for TFLite model.
        # PyTorch uses NCHW (Batch, Channels, Height, Width); the converted model
        # may expect either NCHW or NHWC, so the layout is read from the
        # interpreter rather than assumed.
        expected_shape = input_details[0]['shape']
        expected_layout = expected_shape.tolist()
        input_data_np = dummy_input.numpy()

        if expected_layout == [1, input_height, input_width, 3]: # NHWC
            print("TFLite model expects NHWC input. Transposing dummy input.")
            input_data_np = np.transpose(input_data_np, (0, 2, 3, 1)) # NCHW to NHWC
        elif expected_layout == [1, 3, input_height, input_width]: # NCHW
            print("TFLite model expects NCHW input. No transpose needed.")
        else:
            print(f"Warning: TFLite input shape {expected_shape} is unexpected. Using original dummy input shape.")
            # May need adjustment if this warning appears.

        # Ensure dtype matches
        if input_details[0]['dtype'] == np.float32 and input_data_np.dtype != np.float32:
            input_data_np = input_data_np.astype(np.float32)

        interpreter.set_tensor(input_details[0]['index'], input_data_np)
        interpreter.invoke()
        output_data_tflite = interpreter.get_tensor(output_details[0]['index'])

        print(f"TFLite model inference successful. Output shape: {output_data_tflite.shape}")

        # Compare the PyTorch output with the TFLite output on the same input.
        # A mismatch raises AssertionError, which the handler below reports.
        with torch.no_grad():
            pytorch_output = model(dummy_input)
        pytorch_output_np = pytorch_output.numpy()
        np.testing.assert_allclose(pytorch_output_np, output_data_tflite, rtol=1e-3, atol=1e-5)
        print("PyTorch and TFLite outputs are close (basic check).")

    except Exception as e:
        print(f"Error during TFLite model verification: {e}")

    print("SeResNext to .tflite conversion and verification process finished.")


if __name__ == "__main__":
    # --- Part 1 Execution ---
    # The QuickDraw download stays disabled so that running this cell cannot start
    # a large download by accident (many categories cost time and disk space).
    # Adjust QUICKDRAW_CATEGORIES above if needed, then uncomment to enable:
    # download_quickdraw_subset(QUICKDRAW_CATEGORIES, QUICKDRAW_BASE_URL, DOWNLOAD_DIR)
    for notice in (
        "Part 1 (QuickDraw Download) is commented out by default to prevent accidental large downloads.",
        "Uncomment the line 'download_quickdraw_subset(...)' in the if __name__ == '__main__': block to run it.",
        "-" * 50,
    ):
        print(notice)

    # --- Part 2 Execution ---
    convert_seresnext_to_tflite()

SeResNet-50 using the following configs:
GitHub Copilot
This is a significant update that involves setting up a complete fine-tuning pipeline. Below is a Python script that outlines the major components.

Important Considerations and Assumptions:

Stroke Data for 3-Channel Input:

The QuickDraw .npy files you're downloading (numpy_bitmap/) are 28x28 grayscale bitmaps. They do not contain the raw stroke data required to create channels from 1/3, 2/3, and 3/3 of strokes.
Adaptation Made: To proceed with the downloaded .npy files, this script will resize the 28x28 grayscale image to 224x224 and then replicate this single channel to create a 3-channel image. This is a common workaround when fine-tuning models pre-trained on RGB images (like ImageNet models) with grayscale data.
If you strictly need the stroke-based 3-channel input, you would need to download the .ndjson files (raw stroke data) and implement a more complex preprocessing pipeline to render images from partial strokes.
"Unrecognized" Class:

The prompt mentions "including unrecognized." The QuickDraw dataset doesn't have a generic "unrecognized" class. You would need to select one of the existing 345 classes to serve this purpose (e.g., "The Mona Lisa" is sometimes used) or curate your own "unrecognized" samples. This script will proceed with the 50 specified classes. If you want an "unrecognized" category, add its name to QUICKDRAW_CATEGORIES and ensure its .npy file is downloaded.
MAP@3 for Learning Rate Scheduler:

Calculating Mean Average Precision at 3 (MAP@3) typically requires the model to output a ranked list of predictions and is more common in retrieval or ranking tasks. For a standard classifier outputting logits for 50 classes, it's simpler to use overall validation accuracy or loss.
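Adaptation Made: This script will use validation accuracy as the metric monitored by the ReduceLROnPlateau scheduler. If MAP@3 is a hard requirement, it can be computed in the validation loop from each sample's three highest-scoring classes (for example with torch.topk on the logits) and passed to the scheduler instead; the model's output layer does not need to change for this.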
Averaging Last Ten Weights & TTA for Inference:

The script will implement saving checkpoints.
The logic for averaging the last ten weights and performing horizontal flip Test-Time Augmentation (TTA) will be outlined conceptually for the inference phase, as a full inference script is extensive. The TFLite conversion will typically be done on a single (potentially weight-averaged) fine-tuned model.
Computational Resources: Fine-tuning seresnext50_32x4d with a batch size of 256 requires a capable GPU with sufficient VRAM.

In [None]:
import os
import requests
import urllib.parse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision import transforms
import timm
from sklearn.model_selection import train_test_split # For splitting indices
import random
import glob # For finding checkpoints

# --- Configuration ---
QUICKDRAW_CATEGORIES = [ # Your 50 classes
    'apple', 'ball', 'cat', 'dog', 'elephant',
    'fish', 'grapes', 'house', 'ice cream', 'jellyfish',
    'kite', 'lion', 'moon', 'nose', 'orange',
    'pencil', 'question mark', 'rabbit', 'sun', 'tree',
    'umbrella', 'van', 'watch', 'xylophone', 'yo-yo',
    'zebra', 'airplane', 'ant', 'banana', 'bed',
    'bee', 'bicycle', 'bird', 'boat', 'book',
    'bread', 'bus', 'butterfly', 'cake', 'camera',
    'car', 'chair', 'clock', 'cloud', 'computer',
    'cookie', 'cow', 'crayon', 'cup', 'eraser'
]
# NOTE(review): a name the download server does not serve is skipped by the
# downloader, and a class without a local .npy file is skipped when the
# dataloaders are built, yet NUM_CLASSES still counts it. Verify every entry
# against the official QuickDraw category list.
NUM_CLASSES = len(QUICKDRAW_CATEGORIES)
QUICKDRAW_BASE_URL = "https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/"
DOWNLOAD_DIR = "./quickdraw_data/"
CHECKPOINT_DIR = "./quickdraw_checkpoints/"
MODEL_NAME = 'seresnext50_32x4d'
INPUT_SIZE = 224
BATCH_SIZE = 256 # Adjust if OOM error occurs
LEARNING_RATE = 0.00025
NUM_EPOCHS = 20 # Adjust as needed
VALIDATION_SAMPLES_PER_CLASS = 500
CHECKPOINT_SAVE_STEP = 5000 # Save every 5000 training steps
SEED = 42 # Fixes the train/validation shuffle and weight initialisation between runs

# Seed every random number generator this notebook draws from, so that a
# Restart & Run All reproduces the same data split.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# exist_ok avoids the check-then-create race and makes the cell safe to re-run.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# --- Part 1: Data Download and Preparation ---
def download_quickdraw_subset(categories, base_url, download_dir, timeout=30):
    """Download the QuickDraw ``.npy`` bitmap file for each category.

    Each category is fetched from ``<base_url><url-quoted category>.npy`` and
    stored as ``<download_dir>/<category with spaces replaced by '_'>.npy``.
    A category whose target file already exists and is non-empty is skipped.
    Any failure is reported with a printed message and the next category is tried.

    Args:
        categories: Iterable of category names (e.g. ``'ice cream'``).
        base_url: URL prefix, including the trailing slash.
        download_dir: Directory to write into; created if missing.
        timeout: Seconds to wait for the server before giving up on a request,
            so a stalled connection cannot hang the notebook indefinitely.

    Returns:
        None.
    """
    os.makedirs(download_dir, exist_ok=True)
    print(f"Starting download of {len(categories)} QuickDraw categories...")
    for category in categories:
        sanitized_category_name = urllib.parse.quote(category)
        url = f"{base_url}{sanitized_category_name}.npy"
        output_path = os.path.join(download_dir, f"{category.replace(' ', '_')}.npy")
        # Stream into a side file and rename only when complete: the size check
        # below cannot tell a truncated download from a finished one.
        partial_path = f"{output_path}.part"
        if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
            print(f"File for '{category}' already exists. Skipping.")
            continue
        print(f"Downloading '{category}' from {url}...")
        try:
            # The context manager releases the connection even when streaming fails.
            with requests.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()
                with open(partial_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
            os.replace(partial_path, output_path)
            print(f"Successfully downloaded and saved to {output_path}")
        except Exception as e:
            print(f"Error downloading '{category}': {e}. Skipping.")
        finally:
            # Gone after a successful rename; removes incomplete data after a failure.
            if os.path.exists(partial_path):
                os.remove(partial_path)
    print("QuickDraw download process finished.")

class QuickDrawBitmapDataset(Dataset):
    """Map-style dataset over QuickDraw ``numpy_bitmap`` files.

    Each sample is addressed by a ``(filepath, row_index)`` pair. The row is a
    flat 784-element bitmap that is reshaped to 28x28, scaled to the 0-1 range,
    resized to ``INPUT_SIZE`` x ``INPUT_SIZE`` and replicated to 3 channels so
    it matches the input layout of an RGB-pretrained backbone.

    Args:
        image_paths: List of ``(filepath, index_in_npy)`` tuples.
        labels: List of integer labels, aligned with ``image_paths``.
        class_to_idx: Mapping from class name to integer label (stored only).
        transform: Optional callable applied to the final 3-channel tensor.
    """

    def __init__(self, image_paths, labels, class_to_idx, transform=None):
        self.image_paths = image_paths # List of (filepath, index_in_npy)
        self.labels = labels         # List of integer labels
        self.class_to_idx = class_to_idx
        self.transform = transform
        self._loaded_npy_files = {} # Cache of memory-mapped npy files, keyed by path
        # Built once instead of once per sample.
        self._resize = transforms.Resize((INPUT_SIZE, INPUT_SIZE), antialias=True)

    def __len__(self):
        return len(self.image_paths)

    def _load_image(self, filepath, index_in_npy):
        """Return the sample as a float32 (3, INPUT_SIZE, INPUT_SIZE) tensor, or None on failure."""
        if filepath not in self._loaded_npy_files:
            try:
                # mmap keeps the (large) file on disk; only the requested row is read.
                self._loaded_npy_files[filepath] = np.load(filepath, mmap_mode='r')
            except Exception as e:
                print(f"Error loading {filepath}: {e}")
                return None

        bitmaps = self._loaded_npy_files[filepath]
        try:
            raw_row = bitmaps[index_in_npy]
            img = np.array(raw_row, dtype=np.float32).reshape(28, 28)
            # The published numpy_bitmap files hold uint8 pixels (0-255). Casting
            # alone leaves them in that range, so integer data is scaled to 0-1
            # here; data that is already floating point is left untouched.
            if np.issubdtype(raw_row.dtype, np.integer):
                img /= 255.0
        except IndexError:
            print(f"Index {index_in_npy} out of bounds for {filepath} (len {len(bitmaps)}). Skipping sample.")
            return None # Skip this sample
        except Exception as e:
            print(f"Error processing image {index_in_npy} from {filepath}: {e}")
            return None

        # Create 3-channel image by resizing and replicating the grayscale channel.
        img_tensor = torch.from_numpy(img).unsqueeze(0) # (1, 28, 28)
        img_resized = self._resize(img_tensor) # (1, INPUT_SIZE, INPUT_SIZE)
        img_3_channel = img_resized.repeat(3, 1, 1) # (3, INPUT_SIZE, INPUT_SIZE)

        # No ImageNet mean/std normalization is applied here.
        # NOTE(review): the pretrained backbone was presumably trained with such
        # normalization; consider passing transforms.Normalize via `transform`.

        if self.transform: # For TTA later, not used in training as per "no augmentation"
            img_3_channel = self.transform(img_3_channel)

        return img_3_channel

    def __getitem__(self, idx):
        filepath, index_in_npy = self.image_paths[idx]
        label = self.labels[idx]

        image = self._load_image(filepath, index_in_npy)
        if image is None:
            # A sample that cannot be loaded is replaced by an all-zero image
            # labelled 0 so that default batching keeps working.
            # NOTE(review): this silently adds blank images to class 0.
            print(f"Warning: Could not load image for idx {idx}, path {filepath}, npy_idx {index_in_npy}. Returning dummy data.")
            return torch.zeros((3, INPUT_SIZE, INPUT_SIZE)), torch.tensor(0) # Dummy data

        return image, torch.tensor(label, dtype=torch.long)

def _collate_skip_none(batch):
    """Collate a batch after dropping samples whose image is None.

    Defined at module level (not nested) so DataLoader worker processes can
    pickle it under the "spawn" start method as well as "fork".
    """
    kept = [sample for sample in batch if sample[0] is not None]
    if not kept: # If all samples in batch failed
        return torch.empty(0), torch.empty(0)
    return torch.utils.data.dataloader.default_collate(kept)


def prepare_dataloaders(categories, download_dir, val_samples_per_class, batch_size):
    """Build the training and validation DataLoaders from downloaded .npy files.

    For every category whose file ``<download_dir>/<name with '_' for ' '>.npy``
    exists, the row indices are shuffled once; the first
    ``val_samples_per_class`` go to validation and the rest to training.
    Categories without a file, or whose file cannot be read, are skipped with a
    printed message.

    Args:
        categories: Ordered list of class names; position defines the label.
        download_dir: Directory containing the .npy files.
        val_samples_per_class: Number of rows per class held out for validation.
        batch_size: Batch size for both loaders.

    Returns:
        ``(train_loader, val_loader, class_to_idx)``, or ``(None, None, None)``
        when no data was found or either split is empty.
    """
    class_to_idx = {name: i for i, name in enumerate(categories)}

    train_img_paths, train_labels = [], []
    val_img_paths, val_labels = [], []

    # A single pass per class: count the rows, shuffle, split.
    for class_name in categories:
        class_idx = class_to_idx[class_name]
        npy_path = os.path.join(download_dir, f"{class_name.replace(' ', '_')}.npy")
        if not os.path.exists(npy_path):
            print(f"Warning: NPY file not found for {class_name} at {npy_path}. Skipping this class.")
            continue
        try:
            # Memory-mapped load just to get the row count; pixels are read in the Dataset.
            num_images_in_class = np.load(npy_path, mmap_mode='r').shape[0]
        except Exception as e:
            print(f"Error processing class {class_name} from {npy_path}: {e}")
            continue

        indices = list(range(num_images_in_class))
        random.shuffle(indices)
        val_indices = indices[:val_samples_per_class]
        train_indices = indices[val_samples_per_class:]

        train_img_paths.extend((npy_path, i) for i in train_indices)
        train_labels.extend([class_idx] * len(train_indices))
        val_img_paths.extend((npy_path, i) for i in val_indices)
        val_labels.extend([class_idx] * len(val_indices))

    if not train_img_paths and not val_img_paths:
        print("No data loaded. Exiting.")
        return None, None, None

    print(f"Total training samples: {len(train_img_paths)}")
    print(f"Total validation samples: {len(val_img_paths)}")

    if not train_img_paths or not val_img_paths:
        print("Not enough data for training or validation split. Exiting.")
        return None, None, None

    train_dataset = QuickDrawBitmapDataset(train_img_paths, train_labels, class_to_idx)
    val_dataset = QuickDrawBitmapDataset(val_img_paths, val_labels, class_to_idx)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True, collate_fn=_collate_skip_none)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True, collate_fn=_collate_skip_none)

    return train_loader, val_loader, class_to_idx

# --- Part 2: Model Definition ---
def get_model(num_classes, model_name=MODEL_NAME, pretrained=True):
    """Create a timm model with a ``num_classes``-way output and move it to DEVICE.

    Args:
        num_classes: Number of output classes requested from timm.
        model_name: timm architecture name; defaults to MODEL_NAME.
        pretrained: Whether timm should load its pretrained weights.

    Returns:
        The model, placed on DEVICE.
    """
    # num_classes is handed straight to timm, so the classifier layer
    # (model.fc / model.classifier / model.head, depending on architecture)
    # is not replaced by hand here.
    network = timm.create_model(
        model_name,
        pretrained=pretrained,
        num_classes=num_classes,
    )
    network = network.to(DEVICE)
    return network

# --- Part 3: Training Loop ---
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, checkpoint_dir, checkpoint_save_step):
    """Fine-tune ``model`` and return the paths of the most recent step checkpoints.

    Per epoch: one pass over ``train_loader`` with gradient updates, one pass
    over ``val_loader`` without gradients, a printed summary line, an optional
    scheduler step on validation accuracy, and a save of ``best_model.pth``
    whenever validation accuracy improves.

    Args:
        model: Network to train (expected to already be on DEVICE).
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over the model's parameters.
        scheduler: Object with ``step(metric)`` (e.g. ReduceLROnPlateau) or None.
        num_epochs: Number of epochs to run.
        checkpoint_dir: Directory for ``model_step_<n>.pth`` and ``best_model.pth``.
        checkpoint_save_step: Save a step checkpoint every this many optimizer
            steps; a falsy value disables step checkpoints.

    Returns:
        List with the paths of up to the 10 most recent step checkpoints,
        oldest first. Older checkpoint files are left on disk.
    """
    max_recent_checkpoints = 10
    best_val_acc = 0.0
    global_step = 0
    recent_checkpoints = []

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct_train = 0
        total_train = 0

        for inputs, labels in train_loader:
            if inputs.numel() == 0: # Skip if batch is empty due to collate_fn
                continue
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            samples_in_batch = labels.size(0)
            running_loss += loss.item() * samples_in_batch
            predicted = outputs.argmax(dim=1)
            total_train += samples_in_batch
            correct_train += (predicted == labels).sum().item()

            global_step += 1
            if checkpoint_save_step and global_step % checkpoint_save_step == 0:
                checkpoint_path = os.path.join(checkpoint_dir, f"model_step_{global_step}.pth")
                torch.save(model.state_dict(), checkpoint_path)
                print(f"Checkpoint saved to {checkpoint_path}")
                recent_checkpoints.append(checkpoint_path)
                # Keep only the newest entries in the list (files stay on disk).
                del recent_checkpoints[:-max_recent_checkpoints]

        # Average over the samples actually processed, so loss and accuracy
        # share the same denominator even when batches were skipped.
        epoch_loss = running_loss / total_train if total_train > 0 else 0
        epoch_acc_train = correct_train / total_train if total_train > 0 else 0

        model.eval()
        val_loss = 0.0
        correct_val = 0
        total_val = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                if inputs.numel() == 0:
                    continue
                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                samples_in_batch = labels.size(0)
                val_loss += loss.item() * samples_in_batch
                predicted = outputs.argmax(dim=1)
                total_val += samples_in_batch
                correct_val += (predicted == labels).sum().item()

        epoch_val_loss = val_loss / total_val if total_val > 0 else 0
        epoch_val_acc = correct_val / total_val if total_val > 0 else 0

        print(f"Epoch {epoch+1}/{num_epochs} | "
              f"Train Loss: {epoch_loss:.4f} Acc: {epoch_acc_train:.4f} | "
              f"Val Loss: {epoch_val_loss:.4f} Acc: {epoch_val_acc:.4f}")

        if scheduler:
            scheduler.step(epoch_val_acc) # ReduceLROnPlateau expects a metric

        if epoch_val_acc > best_val_acc:
            best_val_acc = epoch_val_acc
            best_model_path = os.path.join(checkpoint_dir, "best_model.pth")
            torch.save(model.state_dict(), best_model_path)
            print(f"New best model saved to {best_model_path} with Val Acc: {best_val_acc:.4f}")

    print("Training finished.")
    return recent_checkpoints # Return list of last 10 checkpoints

# --- Part 4: Inference Concepts (Not a full script) ---
def average_last_k_weights(checkpoint_paths):
    """Average the state dicts stored at ``checkpoint_paths`` into a new model.

    Every checkpoint that loads and has the same keys as the first usable one
    contributes equally. Checkpoints that fail are reported and left out of
    both the sum and the divisor.

    Args:
        checkpoint_paths: Sequence of paths to saved ``state_dict`` files.

    Returns:
        A model built with ``get_model(NUM_CLASSES, pretrained=False)``, loaded
        with the averaged weights, on DEVICE and in eval mode; or None when the
        list is empty or no checkpoint could be used.
    """
    if not checkpoint_paths:
        print("No checkpoints provided for averaging.")
        return None

    print(f"Averaging weights from: {checkpoint_paths}")
    summed_state = None
    num_valid_checkpoints = 0

    for checkpoint_path in checkpoint_paths:
        try:
            state_dict = torch.load(checkpoint_path, map_location=DEVICE)
            if summed_state is None:
                # float32 copies so integer buffers can be accumulated too.
                candidate = {k: v.clone().to(dtype=torch.float32) for k, v in state_dict.items()}
            else:
                if state_dict.keys() != summed_state.keys():
                    raise KeyError("parameter names differ from the first checkpoint")
                # Build the new sum separately, so that an error part-way through
                # cannot leave a half-updated accumulator behind.
                candidate = {
                    k: summed_state[k] + state_dict[k].to(dtype=torch.float32)
                    for k in summed_state
                }
        except Exception as e:
            print(f"Error loading or adding checkpoint {checkpoint_path}: {e}. Skipping this checkpoint.")
            continue
        summed_state = candidate
        num_valid_checkpoints += 1

    if summed_state is None:
        print("No checkpoint could be loaded for averaging.")
        return None

    # Divide by the number of checkpoints that actually contributed.
    avg_state_dict = {k: v / num_valid_checkpoints for k, v in summed_state.items()}

    # Create a new model and load the averaged weights
    averaged_model = get_model(NUM_CLASSES, pretrained=False) # Don't load timm pretrained weights
    averaged_model.load_state_dict(avg_state_dict)
    averaged_model.to(DEVICE)
    averaged_model.eval()
    print("Weight averaging complete.")
    return averaged_model

def predict_with_tta(model, image_tensor): # image_tensor: (B, C, H, W)
    """Class probabilities averaged over the input and its horizontal flip.

    The model is switched to eval mode, the batch is moved to DEVICE, and the
    softmax outputs for the original and the horizontally flipped batch are
    averaged with equal weight (0.5 each).

    Args:
        model: Classifier returning logits with classes along dimension 1.
        image_tensor: Input batch.

    Returns:
        Tensor of averaged probabilities, same shape as the model's logits.
    """
    model.eval()

    def class_probabilities(batch):
        # Gradients are not needed for inference.
        with torch.no_grad():
            return torch.softmax(model(batch), dim=1)

    on_device = image_tensor.to(DEVICE)
    plain_probs = class_probabilities(on_device)
    mirrored_probs = class_probabilities(transforms.functional.hflip(on_device))

    # Equal weighting of the two views.
    return (plain_probs + mirrored_probs) / 2.0


# --- Part 5: TFLite Conversion (Can be run after fine-tuning) ---
# (This part is similar to your previous script, adapted for the fine-tuned model)
def convert_fine_tuned_to_tflite(model_path, num_classes, output_tflite_path="fine_tuned_seresnext.tflite"):
    """Convert a fine-tuned checkpoint to a .tflite file with AI Edge Torch.

    The architecture is rebuilt with ``get_model(num_classes, pretrained=False)``,
    the weights at ``model_path`` are loaded (a uniform ``module.`` prefix, as
    written by DataParallel, is stripped), and the model is converted with a
    random ``(1, 3, INPUT_SIZE, INPUT_SIZE)`` example input. Conversion runs on
    the CPU regardless of DEVICE, matching the pretrained-model converter in
    this notebook; the exported file does not depend on the training device.

    Args:
        model_path: Path to a saved ``state_dict``.
        num_classes: Number of output classes of the fine-tuned model.
        output_tflite_path: Destination file; parent folders are created.

    Returns:
        None. Every failure is reported with a printed message.
    """
    try:
        import ai_edge_torch # Import locally
    except ImportError:
        print("ai_edge_torch not installed. Skipping TFLite conversion.")
        print("Install with: pip install ai-edge-torch")
        return

    print(f"\nConverting fine-tuned model {model_path} to TFLite...")
    model = get_model(num_classes, pretrained=False) # Load architecture, not pretrained weights

    try:
        # Load the fine-tuned weights onto the CPU, where conversion happens.
        model = model.to("cpu")
        state_dict = torch.load(model_path, map_location="cpu")
        # Handle potential DataParallel prefix 'module.'
        if state_dict and all(key.startswith('module.') for key in state_dict.keys()):
            state_dict = {k[len('module.'):]: v for k, v in state_dict.items()}
        model.load_state_dict(state_dict)
        model.eval()
        print("Fine-tuned model loaded successfully.")
    except Exception as e:
        print(f"Error loading fine-tuned model weights from {model_path}: {e}")
        return

    # Example input on the same device as the model (CPU).
    dummy_input = torch.randn(1, 3, INPUT_SIZE, INPUT_SIZE)
    sample_inputs = (dummy_input,)

    print("Converting model and exporting to TFLite...")
    try:
        edge_model = ai_edge_torch.convert(model, sample_inputs)
        print("Model converted to Edge format.")
        # Make sure the destination folder exists before the file is written.
        output_dir = os.path.dirname(output_tflite_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        edge_model.export(output_tflite_path)
        print(f"TFLite model saved to {output_tflite_path}")
    except Exception as e:
        print(f"Error during AI Edge Torch conversion or export: {e}")

# --- Main Execution ---
if __name__ == "__main__":
    # 1. Download data
    print("--- Starting Data Download ---")
    download_quickdraw_subset(QUICKDRAW_CATEGORIES, QUICKDRAW_BASE_URL, DOWNLOAD_DIR)

    # 2. Prepare Dataloaders
    print("\n--- Preparing Dataloaders ---")
    train_loader, val_loader, class_to_idx = prepare_dataloaders(
        QUICKDRAW_CATEGORIES, DOWNLOAD_DIR, VALIDATION_SAMPLES_PER_CLASS, BATCH_SIZE
    )

    # prepare_dataloaders returns None for both loaders when no usable data was found.
    if train_loader is not None and val_loader is not None:
        # 3. Get Model
        print("\n--- Initializing Model ---")
        model = get_model(NUM_CLASSES, pretrained=True)

        # 4. Define Loss, Optimizer, Scheduler
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
        # Reduce LR when validation accuracy has stopped improving.
        # The `verbose` argument is not passed: it is deprecated in recent PyTorch releases.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)

        # 5. Train Model
        print("\n--- Starting Training ---")
        os.makedirs(CHECKPOINT_DIR, exist_ok=True)

        recent_checkpoints = train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, NUM_EPOCHS, CHECKPOINT_DIR, CHECKPOINT_SAVE_STEP)

        # 6. (Optional) Average last K weights after training.
        # The best-validation checkpoint is the default source for conversion, so
        # path_for_tflite is always defined, even when averaging is skipped or fails.
        best_model_path = os.path.join(CHECKPOINT_DIR, "best_model.pth")
        path_for_tflite = best_model_path

        if recent_checkpoints:
            print("\n--- Averaging Last K Weights (Example) ---")
            # train_model returns this run's step checkpoints oldest-first. Using
            # that list avoids two problems with globbing the directory: a plain
            # sort orders "model_step_10000" before "model_step_5000", and files
            # left over from earlier runs would be mixed in.
            last_10_step_checkpoints = recent_checkpoints[-10:]
            averaged_model = average_last_k_weights(last_10_step_checkpoints)
            if averaged_model is not None:
                averaged_model_path = os.path.join(CHECKPOINT_DIR, "averaged_model.pth")
                torch.save(averaged_model.state_dict(), averaged_model_path)
                print(f"Averaged model saved to {averaged_model_path}")
                # Path for TFLite conversion is this averaged model
                path_for_tflite = averaged_model_path
            else:
                print("Weight averaging failed. Using best_model.pth for TFLite conversion if available.")
        else:
            print("Not enough step checkpoints to average. Using best_model.pth for TFLite conversion if available.")

        # 7. Convert the fine-tuned model (e.g., best or averaged) to TFLite
        if os.path.exists(path_for_tflite):
            convert_fine_tuned_to_tflite(path_for_tflite, NUM_CLASSES)
        else:
            print(f"Model file {path_for_tflite} not found for TFLite conversion.")

    else:
        print("Dataloaders could not be prepared. Exiting fine-tuning process.")

    print("\n--- Script Finished ---")
