In [None]:
!pip install ofa

Collecting ofa
  Downloading ofa-0.1.0.post202307202001-py3-none-any.whl.metadata (1.4 kB)
Downloading ofa-0.1.0.post202307202001-py3-none-any.whl (107 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m107.6/107.6 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ofa
Successfully installed ofa-0.1.0.post202307202001


In [None]:
import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Subset # <-- For calibration
import random # <-- For calibration subset
from PIL import Image
import urllib.request
import json
import time
import sys
import os
import requests
import tarfile  # For .tgz files
import shutil   # For deleting a corrupt zip/dir
# import pandas as pd # <-- No longer needed

try:
    from ofa.model_zoo import ofa_net
except ImportError as e:
    print("Error: 'ofa' library not found or a component is missing.")
    print(f"Import error details: {e}")
    print("Please run: pip install ofa")
    sys.exit(1)

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {DEVICE}")

Using device: cuda


In [None]:
# --- 1. SETUP & DATA LOADING ---

# Download ImageNet 1000-class labels (we need this for the mapping)
print("Downloading ImageNet 1000-class index...")
label_url = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
class_idx = json.load(urllib.request.urlopen(label_url))
imagenet_labels = [class_idx[str(k)][1] for k in range(len(class_idx))]
# Create a WordNet ID -> 1000-class index mapping
# e.g., 'n02110185' -> 258 (Samoyed)
wordnet_to_1000_idx = {v[0]: int(k) for k, v in class_idx.items()}
print("ImageNet 1000-class index loaded.")

def download_and_untar_imagenette():
    """
    Downloads and unzips the Imagenette dataset if not already present.
    """
    dataset_dir = "imagenette2-160"
    tgz_path = f"{dataset_dir}.tgz"
    url = "https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz"

    if os.path.exists(dataset_dir):
        print("Imagenette directory already exists.")
        return dataset_dir

    if not os.path.exists(tgz_path):
        print(f"Downloading Imagenette (150MB) from {url}...")
        print("This may take a few minutes...")
        try:
            with requests.get(url, stream=True) as r:
                r.raise_for_status()
                with open(tgz_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
            print("Download complete.")
        except Exception as e:
            print(f"Error downloading Imagenette: {e}")
            if os.path.exists(tgz_path):
                os.remove(tgz_path) # Clean up partial download
            sys.exit(1)

    print(f"Un-tarring {tgz_path}...")
    try:
        with tarfile.open(tgz_path, "r:gz") as tar_ref:
            tar_ref.extractall()
        print(f"Successfully un-tarred to {dataset_dir}.")
        # os.remove(tgz_path) # Optional: remove tgz file after extraction
    except Exception as e:
        print(f"Error un-tarring file: {e}. Please delete it and re-run.")
        shutil.rmtree(dataset_dir, ignore_errors=True) # Clean up bad extraction
        os.remove(tgz_path)
        sys.exit(1)

    return dataset_dir

# Standard ImageNet transforms
def get_transforms():
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

# --- NEW DATA LOADING FUNCTION ---
def load_imagenette_val_set(dataset_dir, wordnet_to_1000_idx_map, max_images=100):
    """
    Loads a subset of the Imagenette validation set using ImageFolder
    and maps its 0-9 classes to the correct 1000-class ImageNet index.
    """
    print(f"Loading {max_images} images from Imagenette validation set...")

    transform = get_transforms()

    val_dir = os.path.join(dataset_dir, 'val')
    if not os.path.exists(val_dir):
        print(f"Error: Validation directory not found at {val_dir}")
        return []

    # Load the dataset using ImageFolder
    val_dataset = ImageFolder(root=val_dir, transform=transform)

    # Create the mapping from ImageFolder's 0-9 index to the 1000-class index
    # val_dataset.class_to_idx is {'n01440764': 0, 'n02102040': 1, ...}
    imgfolder_idx_to_1000_idx = {}
    for wordnet_id, imgfolder_idx in val_dataset.class_to_idx.items():
        if wordnet_id in wordnet_to_1000_idx_map:
            idx_1000 = wordnet_to_1000_idx_map[wordnet_id]
            imgfolder_idx_to_1000_idx[imgfolder_idx] = idx_1000
        else:
            # This class isn't in the 1000-class set (shouldn't happen for Imagenette)
            imgfolder_idx_to_1000_idx[imgfolder_idx] = -1 # Flag to skip

    print("ImageFolder mapping to 1000-class index created.")

    val_set = []
    # Load a subset of images
    for i in range(len(val_dataset)):
        if i >= max_images:
            break

        try:
            # Get the pre-processed image and its 0-9 index
            input_tensor, imgfolder_idx = val_dataset[i]

            # Translate to the 1000-class index
            true_label_1000_idx = imgfolder_idx_to_1000_idx.get(imgfolder_idx, -1)

            if true_label_1000_idx != -1:
                true_label_name = imagenet_labels[true_label_1000_idx]
                # Add the batch dimension
                val_set.append((input_tensor.unsqueeze(0).to(DEVICE), true_label_1000_idx, true_label_name))
        except Exception as e:
            print(f"Warning: Could not process image at index {i}. Skipping. Error: {e}")

    print(f"Val set ready with {len(val_set)} images.")
    return val_set

# --- NEW: Function to load training data for BN calibration ---
def load_imagenette_train_loader(dataset_dir, batch_size=64, num_samples=2000):
    """
    Loads a subset of the Imagenette training set for BN calibration.
    """
    print(f"Loading calibration data from Imagenette train set...")
    transform = get_transforms()
    train_dir = os.path.join(dataset_dir, 'train')

    if not os.path.exists(train_dir):
        print(f"Error: Training directory not found at {train_dir}")
        return None

    train_dataset = ImageFolder(root=train_dir, transform=transform)

    # Create a random subset for faster calibration
    num_total_samples = len(train_dataset)
    subset_indices = random.sample(range(num_total_samples), min(num_samples, num_total_samples))
    calibration_subset = Subset(train_dataset, subset_indices)

    train_loader = DataLoader(
        calibration_subset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2
    )
    print(f"Calibration data loader ready with {len(calibration_subset)} images.")
    return train_loader

Downloading ImageNet 1000-class index...
ImageNet 1000-class index loaded.


In [None]:
# --- 2. EVALUATION FUNCTIONS ---

# --- NEW: Function to calibrate Batch Norm stats ---
def calibrate_model_bn(model, train_loader, model_name):
    """
    Recalculates the running mean and variance for the Batch Norm layers
    of a sampled sub-network.
    """
    print(f"Calibrating Batch Norm for {model_name}...")
    model.to(DEVICE)

    # Set model to train() mode to update BN stats
    model.train()

    with torch.no_grad():
        for i, (images, _) in enumerate(train_loader):
            images = images.to(DEVICE)
            # Just run the forward pass to update BN stats
            model(images)
            if i > 100: # Calibrate on ~100 batches
                break

    # Set model back to eval() mode for inference
    model.eval()
    print(f"Calibration complete for {model_name}.")


# Approximate FLOPs calculation
def get_flops(model, input_shape=(1, 3, 224, 224)):
    model.eval()
    input_tensor = torch.randn(input_shape).to(DEVICE)
    flops = 0
    def hook(module, input, output):
        nonlocal flops
        # This is a rough approximation (counts MACs)
        if isinstance(module, torch.nn.Conv2d):
            # output H * W * C_out * (C_in * K_h * K_w)
            flops += int(output.nelement() * module.in_channels * module.kernel_size[0] * module.kernel_size[1] / module.groups)
        elif isinstance(module, torch.nn.Linear):
            # output N * C_out * (C_in)
            flops += int(output.nelement() * module.in_features)

    handles = [m.register_forward_hook(hook) for m in model.modules()]
    with torch.no_grad():
        model(input_tensor)
    for h in handles:
        h.remove()
    return f"{flops / 1e9:.2f} G" # Return Giga-FLOPs

# Updated evaluation with FLOPs and more runs
def run_evaluation(model, val_set, model_name, n_runs_per_image=1):
    model.eval() # Ensure model is in eval mode for evaluation
    model.to(DEVICE)

    params_m = sum(p.numel() for p in model.parameters()) / 1e6

    # Run FLOPs calculation once
    try:
        flops = get_flops(model)
    except Exception as e:
        print(f"Warning: Could not calculate FLOPs for {model_name}. Error: {e}")
        flops = "N/A"

    correct_predictions = 0
    total_images = len(val_set)
    total_latency_ms = 0

    print(f"\n--- Evaluating {model_name} (avg over {n_runs_per_image} runs per image) ---")

    with torch.no_grad():
        # Extended warm-up
        if val_set:
            print("Warming up...")
            for _ in range(3):
                _ = model(val_set[0][0])

        print("Starting evaluation...")
        for i, (input_tensor, true_label_idx, true_label_name) in enumerate(val_set):
            current_image_time_ms = 0
            output = None
            for _ in range(n_runs_per_image):
                start_time = time.time()
                # input_tensor is already pre-processed and on the correct device
                output = model(input_tensor)
                end_time = time.time()
                current_image_time_ms += (end_time - start_time) * 1000

            avg_time_for_this_image = current_image_time_ms / n_runs_per_image
            total_latency_ms += avg_time_for_this_image

            probabilities = torch.nn.functional.softmax(output[0], dim=0)
            top1_prob, top1_catid = torch.topk(probabilities, 1)

            prediction_idx = top1_catid.item()
            prediction_name = imagenet_labels[prediction_idx]

            # Simple progress print
            if (i+1) % 20 == 0:
                print(f"  Processed {i+1}/{total_images} images...")

            if prediction_idx == true_label_idx:
                correct_predictions += 1
                # print(f" [✓] Correct: '{true_label_name}' (Latency: {avg_time_for_this_image:.2f} ms)")
            # else:
                # print(f" [X] WRONG: Predicted '{prediction_name}', True label was '{true_label_name}' (Latency: {avg_time_for_this_image:.2f} ms)")

    if total_images == 0:
        print("Error: No images were loaded in the validation set. Cannot evaluate.")
        return {"name": model_name, "params_m": f"{params_m:.2f} M", "flops": flops, "accuracy": "N/A", "latency": "N/A"}

    accuracy = (correct_predictions / total_images) * 100
    avg_latency = total_latency_ms / total_images

    print(f"\nResult: {correct_predictions} / {total_images} correct ({accuracy:.1f}%)")
    print(f"Efficiency (Size): {params_m:.2f} Million Parameters")
    print(f"FLOPs (Approx): {flops}")
    print(f"Avg Latency per Image: {avg_latency:.2f} ms ({'(GPU)' if DEVICE == 'cuda' else '(CPU)'})")

    return {"name": model_name, "params_m": f"{params_m:.2f} M", "flops": flops, "accuracy": f"{accuracy:.1f}%", "latency": f"{avg_latency:.2f} ms"}


In [None]:
# --- 3. MAIN EXECUTION ---

# Prepare Imagenette
dataset_dir = download_and_untar_imagenette()
# We pass the wordnet_to_1000_idx map to the loading function
val_set = load_imagenette_val_set(dataset_dir, wordnet_to_1000_idx, max_images=100)
# --- NEW: Load calibration data ---
train_loader = load_imagenette_train_loader(dataset_dir)


if not val_set or not train_loader:
    print("Error: Validation or Training set is empty. Cannot proceed.")
    sys.exit(1)

# Load supernet
print("\nLoading OFA MobileNetV3 super-network...")
ofa_super_network = ofa_net('ofa_mbv3_d234_e346_k357_w1.2', pretrained=True)
print("Super-network loaded.")

# Sample and evaluate subnetworks
results = []

# Tiny spec (unchanged)
tiny_spec = {'d': [2]*5, 'e': [3]*20, 'k': [3]*20}
print("\nSampling 'Tiny Model'...")
ofa_super_network.set_active_subnet(ks=tiny_spec['k'], e=tiny_spec['e'], d=tiny_spec['d'])
tiny_model = ofa_super_network.get_active_subnet(preserve_weight=True)
# --- NEW: Calibrate BN ---
calibrate_model_bn(tiny_model, train_loader, "Tiny Model")
results.append(run_evaluation(tiny_model, val_set, "Tiny Model (Target: Microcontroller)"))

# New: Medium spec
medium_spec = {'d': [3]*5, 'e': [4]*10 + [5]*10, 'k': [5]*20}
print("\nSampling 'Medium Model'...")
ofa_super_network.set_active_subnet(ks=medium_spec['k'], e=medium_spec['e'], d=medium_spec['d'])
medium_model = ofa_super_network.get_active_subnet(preserve_weight=True)
# --- NEW: Calibrate BN ---
calibrate_model_bn(medium_model, train_loader, "Medium Model")
results.append(run_evaluation(medium_model, val_set, "Medium Model (Target: Edge Device)"))

# Large spec (unchanged)
large_spec = {'d': [4]*5, 'e': [6]*20, 'k': [7]*20}
print("\nSampling 'Large Model'...")
ofa_super_network.set_active_subnet(ks=large_spec['k'], e=large_spec['e'], d=large_spec['d'])
large_model = ofa_super_network.get_active_subnet(preserve_weight=True)
# --- NEW: Calibrate BN ---
calibrate_model_bn(large_model, train_loader, "Large Model")
results.append(run_evaluation(large_model, val_set, "Large Model (Target: Smartphone)"))

# Final comparison with FLOPs
print("\n\n--- Demo Summary ---")
print(f"Evaluated on a {len(val_set)}-image subset of the Imagenette validation set.")
print("All models sampled from the SAME pre-trained super-network, with NO re-training.")
print("-" * 100)
print(f"| {'Model':<30} | {'Params (Size)':<15} | {'FLOPs (Approx)':<15} | {'Avg Latency':<15} | {'Accuracy':<15} |")
print(f"| {'-'*30} | {'-'*15} | {'-'*15} | {'-'*15} | {'-'*15} |")
for res in results:
    print(f"| {res['name']:<30} | {res['params_m']:<15} | {res['flops']:<15} | {res['latency']:<15} | {res['accuracy']:<15} |")
print("-" * 100)


Downloading Imagenette (150MB) from https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz...
This may take a few minutes...
Download complete.
Un-tarring imagenette2-160.tgz...


  tar_ref.extractall()


Successfully un-tarred to imagenette2-160.
Loading 100 images from Imagenette validation set...
ImageFolder mapping to 1000-class index created.
Val set ready with 100 images.
Loading calibration data from Imagenette train set...
Calibration data loader ready with 2000 images.

Loading OFA MobileNetV3 super-network...


Downloading: "https://raw.githubusercontent.com/han-cai/files/master/ofa/ofa_nets/ofa_mbv3_d234_e346_k357_w1.2" to .torch/ofa_nets/ofa_mbv3_d234_e346_k357_w1.2


Super-network loaded.

Sampling 'Tiny Model'...
Calibrating Batch Norm for Tiny Model...
Calibration complete for Tiny Model.

--- Evaluating Tiny Model (Target: Microcontroller) (avg over 1 runs per image) ---
Warming up...
Starting evaluation...
  Processed 20/100 images...
  Processed 40/100 images...
  Processed 60/100 images...
  Processed 80/100 images...
  Processed 100/100 images...

Result: 75 / 100 correct (75.0%)
Efficiency (Size): 4.60 Million Parameters
FLOPs (Approx): 0.19 G
Avg Latency per Image: 5.60 ms ((GPU))

Sampling 'Medium Model'...
Calibrating Batch Norm for Medium Model...
Calibration complete for Medium Model.

--- Evaluating Medium Model (Target: Edge Device) (avg over 1 runs per image) ---
Warming up...
Starting evaluation...
  Processed 20/100 images...
  Processed 40/100 images...
  Processed 60/100 images...
  Processed 80/100 images...
  Processed 100/100 images...

Result: 86 / 100 correct (86.0%)
Efficiency (Size): 7.20 Million Parameters
FLOPs (Approx)