<a href="https://colab.research.google.com/github/RH00000/UH_RTS_Research_ML/blob/main/first_idk_cascade.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [17]:
!pip install torch torchvision tensorflow-datasets



In [19]:
#Load the ImaageNetV2 TopImages split
import tensorflow_datasets as tfds
import torch
from torch.utils.data import IterableDataset, DataLoader
from torchvision import transforms
from PIL import Image

# 1) Download & prepare the TFDS builder for ImageNetV2 TopImages
builder = tfds.builder("imagenet_v2", config="topimages")
builder.download_and_prepare()

Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/imagenet_v2/topimages/3.0.0...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Generating splits...:   0%|          | 0/1 [00:00<?, ? splits/s]

Generating test examples...: 0 examples [00:00, ? examples/s]

Shuffling /root/tensorflow_datasets/imagenet_v2/topimages/incomplete.BOP971_3.0.0/imagenet_v2-test.tfrecord*..…

Dataset imagenet_v2 downloaded and prepared to /root/tensorflow_datasets/imagenet_v2/topimages/3.0.0. Subsequent calls will reuse this data.


In [20]:
#2 Load the 'test' split (10000 images, top‑images variant)
tfds_ds = builder.as_dataset(split="test", as_supervised=True)

In [21]:
#3 Define a tiny IterableDataset wrapper
class ImageNetV2TopImages(IterableDataset):
    def __init__(self, tfds_dataset, transform=None):
        self.ds = tfds_dataset
        self.transform = transform

    def __iter__(self):
        for img, label in tfds.as_numpy(self.ds):
            # img: HWC uint8 array, label: int
            pil = Image.fromarray(img)
            if self.transform:
                pil = self.transform(pil)
            yield pil, label

In [22]:
#4 Torch transforms (same as ResNet expects)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],
                         [0.229,0.224,0.225]),
])

In [23]:
#5 Instantiate DataLoader
dataset = ImageNetV2TopImages(tfds_ds, transform=preprocess)
loader  = DataLoader(dataset, batch_size=1, num_workers=4)



In [24]:
#Define models and cascade logic
import time
import torch.nn.functional as F
from torchvision import models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load pretrained ResNets
resnet18  = models.resnet18(pretrained=True).to(device).eval()    #model A
resnet34  = models.resnet34(pretrained=True).to(device).eval()    #model B
resnet152 = models.resnet152(pretrained=True).to(device).eval()   #model C

# Thresholds
SKIP_THRESH = 0.3   # if ResNet18 conf < 0.3, skip ResNet34
THR_A = 0.7 # IDK threshold for model A
THR_B = 0.8 # IDK thresholds for model B

def predict(model, x):
    logits = model(x)
    probs  = F.softmax(logits, dim=1)
    conf, cls = torch.max(probs, dim=1)
    return cls.item(), conf.item()

def get_prediction(model, x, thr=None):
    cls, conf = predict(model, x)
    if thr is not None and conf < thr:
        return "IDK", conf
    return cls, conf

def dynamic_idk_cascade(x):
    # A: ResNet18
    cls_a, conf_a = get_prediction(resnet18, x, THR_A)
    if cls_a != "IDK":
        return cls_a, conf_a, "ResNet18"
    # Skip logic
    if conf_a < SKIP_THRESH:
        cls_c, conf_c = predict(resnet152, x)
        return cls_c, conf_c, "ResNet152 (skipped B)"
    # B: ResNet34
    cls_b, conf_b = get_prediction(resnet34, x, THR_B)
    if cls_b != "IDK":
        return cls_b, conf_b, "ResNet34"
    # C: ResNet152
    cls_c, conf_c = predict(resnet152, x)
    return cls_c, conf_c, "ResNet152"

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 112MB/s]
Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth
100%|██████████| 83.3M/83.3M [00:00<00:00, 91.5MB/s]
Downloading: "https://download.pytorch.org/models/resnet152-394f9c45.pth" to /root/.cache/torch/hub/checkpoints/resnet152-394f9c45.pth
100%|██████████| 230M/230M [00:02<00:00, 95.8MB/s]


In [25]:
# RUN & EVALUATE
from collections import Counter

total      = 0
correct    = 0
sum_time   = 0.0
branch_sum = Counter()      # accumulate time per branch
branch_cnt = Counter()      # count per branch

for imgs, labels in loader:
    imgs, labels = imgs.to(device), labels.to(device)
    start = time.time()
    pred, conf, model_used = dynamic_idk_cascade(imgs)
    elapsed = time.time() - start

    total   += 1
    correct += (pred == labels.item())
    sum_time += elapsed

    branch_sum[model_used] += elapsed
    branch_cnt[model_used] += 1

    # optional per‐image print
    print(f"[{total:5d}] GT={labels.item():4d}  Pred={pred:4d}  "
          f"Conf={conf:.2f}  Used={model_used:<20s}  "
          f"Time={elapsed*1000:.1f}ms")

# Core metrics
accuracy      = correct / total
avg_time      = sum_time / total          # seconds per image

# Skip rate = fraction of times we jumped from A straight to C
skip_rate     = branch_cnt["ResNet152 (skipped B)"] / total

# Print summary
print("\n=== SUMMARY ===")
print(f"Total images         : {total}")
print(f"Accuracy             : {accuracy*100:.2f}%")
print(f"Avg. time per input : {avg_time*1000:.1f} ms")
print(f"Skip rate (A→C)      : {skip_rate*100:.2f}%\n")

# Breakdown by branch
for branch in branch_cnt:
    cnt = branch_cnt[branch]
    t   = branch_sum[branch]
    print(f"  {branch:>20s} | Count: {cnt:5d} | "
          f"Avg time: {t/cnt*1000:.1f} ms")




[1;30;43mStreaming output truncated to the last 5000 lines.[0m
[ 6439] GT= 533  Pred= 533  Conf=0.86  Used=ResNet18              Time=89.4ms
[ 6440] GT= 533  Pred= 533  Conf=0.86  Used=ResNet18              Time=84.1ms
[ 6441] GT= 514  Pred= 514  Conf=0.97  Used=ResNet18              Time=83.5ms
[ 6442] GT= 514  Pred= 514  Conf=0.97  Used=ResNet18              Time=88.8ms
[ 6443] GT= 514  Pred= 514  Conf=0.97  Used=ResNet18              Time=84.4ms
[ 6444] GT= 514  Pred= 514  Conf=0.97  Used=ResNet18              Time=85.8ms
[ 6445] GT= 625  Pred= 625  Conf=1.00  Used=ResNet18              Time=92.2ms
[ 6446] GT= 625  Pred= 625  Conf=1.00  Used=ResNet18              Time=102.7ms
[ 6447] GT= 625  Pred= 625  Conf=1.00  Used=ResNet18              Time=109.8ms
[ 6448] GT= 625  Pred= 625  Conf=1.00  Used=ResNet18              Time=84.4ms
[ 6449] GT=  76  Pred=  77  Conf=0.82  Used=ResNet18              Time=84.0ms
[ 6450] GT=  76  Pred=  77  Conf=0.82  Used=ResNet18              Time=88.1

KeyboardInterrupt: 

In [26]:
print(f"total: {total}, correct: {correct}")
print(f"branches seen: {list(branch_cnt.keys())}")


total: 11438, correct: 9110
branches seen: ['ResNet152 (skipped B)', 'ResNet18', 'ResNet152', 'ResNet34']


In [27]:
# Calculate metrics based on partial run
accuracy      = correct / total if total > 0 else 0
avg_time      = sum_time / total if total > 0 else 0
skip_rate     = branch_cnt["ResNet152 (skipped B)"] / total if "ResNet152 (skipped B)" in branch_cnt else 0

print("\n=== PARTIAL SUMMARY ===")
print(f"Images processed      : {total}")
print(f"Accuracy              : {accuracy * 100:.2f}%")
print(f"Avg time per input    : {avg_time * 1000:.1f} ms")
print(f"Skip rate (A → C)     : {skip_rate * 100:.2f}%\n")

for branch in branch_cnt:
    count = branch_cnt[branch]
    total_time = branch_sum[branch]
    avg_branch_time = total_time / count * 1000
    print(f"  {branch:>20s} | Count: {count:5d} | Avg time: {avg_branch_time:.1f} ms")



=== PARTIAL SUMMARY ===
Images processed      : 11438
Accuracy              : 79.65%
Avg time per input    : 299.3 ms
Skip rate (A → C)     : 10.11%

  ResNet152 (skipped B) | Count:  1156 | Avg time: 598.5 ms
              ResNet18 | Count:  6954 | Avg time: 97.8 ms
             ResNet152 | Count:  2380 | Avg time: 755.5 ms
              ResNet34 | Count:   948 | Avg time: 266.9 ms
