<a href="https://colab.research.google.com/github/RH00000/UH_RTS_Research_ML/blob/main/skipping_0.7confidence_0.3skip_threshold_t4GPU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch torchvision tensorflow-datasets

In [2]:
#Load the ImaageNetV2 TopImages split
import tensorflow_datasets as tfds
import torch
from torch.utils.data import IterableDataset, DataLoader
from torchvision import transforms
from PIL import Image

# 1) Download & prepare the TFDS builder for ImageNetV2 TopImages
builder = tfds.builder("imagenet_v2", config="topimages")
builder.download_and_prepare()

Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/imagenet_v2/topimages/3.0.0...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Generating splits...:   0%|          | 0/1 [00:00<?, ? splits/s]

Generating test examples...: 0 examples [00:00, ? examples/s]

Shuffling /root/tensorflow_datasets/imagenet_v2/topimages/incomplete.QGL1U7_3.0.0/imagenet_v2-test.tfrecord*..…

Dataset imagenet_v2 downloaded and prepared to /root/tensorflow_datasets/imagenet_v2/topimages/3.0.0. Subsequent calls will reuse this data.


In [3]:
#2 Load the 'test' split (10000 images, top‑images variant)
tfds_ds = builder.as_dataset(split="test", as_supervised=True)

In [4]:
#3 Define a tiny IterableDataset wrapper
class ImageNetV2TopImages(IterableDataset):
    def __init__(self, tfds_dataset, transform=None):
        self.ds = tfds_dataset
        self.transform = transform

    def __iter__(self):
        for img, label in tfds.as_numpy(self.ds):
            # img: HWC uint8 array, label: int
            pil = Image.fromarray(img)
            if self.transform:
                pil = self.transform(pil)
            yield pil, label

In [5]:
#4 Torch transforms (same as ResNet expects)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],
                         [0.229,0.224,0.225]),
])

In [6]:
#5 Instantiate DataLoader
dataset = ImageNetV2TopImages(tfds_ds, transform=preprocess)
loader  = DataLoader(dataset, batch_size=1, num_workers=4)



In [7]:
#Define models and cascade logic
import time
import torch.nn.functional as F
from torchvision import models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load pretrained ResNets
resnet18  = models.resnet18(pretrained=True).to(device).eval()    #model A
resnet34  = models.resnet34(pretrained=True).to(device).eval()    #model B
resnet50  = models.resnet50(pretrained=True).to(device).eval()    #model C
resnet152 = models.resnet152(pretrained=True).to(device).eval()   #model D

# Thresholds
SKIP_THRESH_B = 0.3   # if ResNet18 conf < 0.3, skip ResNet34
SKIP_THRESH_C = 0.3   # if ResNet34 conf < 0.3, skip ResNet50
THR_A = 0.7 # IDK threshold for model A
THR_B = 0.7 # IDK threshold for model B
THR_C = 0.7 # IDK threshold for model C

def predict(model, x):
    logits = model(x)
    probs  = F.softmax(logits, dim=1)
    conf, cls = torch.max(probs, dim=1)
    return cls.item(), conf.item()

def get_prediction(model, x, thr=None):
    cls, conf = predict(model, x)
    if thr is not None and conf < thr:
        return "IDK", conf
    return cls, conf

def dynamic_abcd_cascade(x):
    # A: ResNet18
    cls_a, conf_a = get_prediction(resnet18, x, THR_A)
    if cls_a != "IDK":
        return cls_a, conf_a, "ResNet18"

    skipped_b = False
    skipped_c = False
    conf_b = 0.0     # Default

    # maybe skip B if A’s confidence is really low
    if conf_a < SKIP_THRESH_B:
        skipped_b = True
    else:
        # B: ResNet34
        cls_b, conf_b = get_prediction(resnet34, x, THR_B)
        if cls_b != "IDK":
            return cls_b, conf_b, "ResNet34"
        skipped_b = False

    # maybe skip C if B was IDK with very low conf
    if skipped_b or (conf_b < SKIP_THRESH_C):
        skipped_c = True
    else:
        # C: ResNet50
        cls_c, conf_c = get_prediction(resnet50, x, THR_C)
        if cls_c != "IDK":
            return cls_c, conf_c, "ResNet50"
        skipped_c = False

    # D: ResNet152 (no threshold)
    cls_d, conf_d = predict(resnet152, x)
    branch_name = "ResNet152"
    # annotate if we skipped C or (B&C) on this path
    if skipped_b and skipped_c:
        branch_name += " (skipped B,C)"
    elif skipped_c:
        branch_name += " (skipped C)"
    return cls_d, conf_d, branch_name

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 167MB/s]
Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth
100%|██████████| 83.3M/83.3M [00:00<00:00, 101MB/s]
Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 169MB/s]
Downloading: "https://download.pytorch.org/models/resnet152-394f9c45.pth" to /root/.cache/torch/hub/checkpoints/resnet152-394f9c45.pth
100%|██████████| 230M/230M [00:01<00:00, 180MB/s]


In [8]:
from collections import Counter
import time

# Prepare counters
branch_sum = Counter()
branch_cnt = Counter()

# RUN & EVALUATE
MAX_IMAGES = 10000
total      = 0
correct    = 0
sum_time   = 0.0

for imgs, labels in loader:
    total += 1
    imgs, labels = imgs.to(device), labels.to(device)

    start = time.time()
    pred, conf, branch = dynamic_abcd_cascade(imgs)
    elapsed = time.time() - start

    # Update overall stats
    correct += (pred == labels.item())
    sum_time += elapsed

    # Update branch‑specific stats
    branch_sum[branch] += elapsed
    branch_cnt[branch] += 1

    # Light logging
    if total % 500 == 0:
        print(f"[{total:5d}] Pred={pred:4d}  Used={branch:<20s}  Time={elapsed*1000:.1f}ms")

    if total >= MAX_IMAGES:
        break

# Core metrics
accuracy  = correct / total
avg_time  = sum_time / total

# Valid skip rates
skip_C_to_D   = branch_cnt.get("ResNet152 (skipped C)", 0) / total
skip_BC_to_D  = branch_cnt.get("ResNet152 (skipped B,C)", 0) / total

# Print summary
print("\n=== SUMMARY ===")
print(f"Total images         : {total}")
print(f"Accuracy             : {accuracy*100:.2f}%")
print(f"Avg. time per input  : {avg_time*1000:.1f} ms")
print(f"Skip C→D rate        : {skip_C_to_D*100:.2f}%")
print(f"Skip B,C→D rate      : {skip_BC_to_D*100:.2f}%")

# Per-branch breakdown
for branch, cnt in branch_cnt.items():
    t = branch_sum[branch]
    print(f"  {branch:>20s} | Count: {cnt:5d} | Avg time: {t/cnt*1000:.1f} ms")

[  500] Pred= 230  Used=ResNet152             Time=55.8ms
[ 1000] Pred= 225  Used=ResNet34              Time=15.1ms
[ 1500] Pred=  35  Used=ResNet18              Time=6.2ms
[ 2000] Pred= 244  Used=ResNet18              Time=5.8ms
[ 2500] Pred= 786  Used=ResNet18              Time=4.5ms
[ 3000] Pred= 420  Used=ResNet18              Time=5.3ms
[ 3500] Pred= 895  Used=ResNet18              Time=4.9ms
[ 4000] Pred= 999  Used=ResNet34              Time=22.0ms
[ 4500] Pred=  17  Used=ResNet18              Time=4.2ms
[ 5000] Pred= 200  Used=ResNet152             Time=42.2ms
[ 5500] Pred= 981  Used=ResNet34              Time=11.3ms
[ 6000] Pred= 457  Used=ResNet18              Time=4.9ms
[ 6500] Pred= 939  Used=ResNet152 (skipped B,C)  Time=30.3ms
[ 7000] Pred= 710  Used=ResNet18              Time=5.1ms
[ 7500] Pred= 977  Used=ResNet50              Time=20.1ms
[ 8000] Pred= 914  Used=ResNet152             Time=70.8ms
[ 8500] Pred= 760  Used=ResNet18              Time=4.2ms
[ 9000] Pred= 685  U