<a href="https://colab.research.google.com/github/RH00000/UH_RTS_Research_ML/blob/main/non_skipping_0.95confidence_t4GPU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch torchvision tensorflow-datasets

In [6]:
#Load the ImaageNetV2 TopImages split
import tensorflow_datasets as tfds
import torch
from torch.utils.data import IterableDataset, DataLoader
from torchvision import transforms
from PIL import Image

# 1) Download & prepare the TFDS builder for ImageNetV2 TopImages
builder = tfds.builder("imagenet_v2", config="topimages")
builder.download_and_prepare()

In [None]:
#2 Load the 'test' split (10000 images, top‑images variant)
tfds_ds = builder.as_dataset(split="test", as_supervised=True)

In [None]:
#3 Define a tiny IterableDataset wrapper
class ImageNetV2TopImages(IterableDataset):
    def __init__(self, tfds_dataset, transform=None):
        self.ds = tfds_dataset
        self.transform = transform

    def __iter__(self):
        for img, label in tfds.as_numpy(self.ds):
            # img: HWC uint8 array, label: int
            pil = Image.fromarray(img)
            if self.transform:
                pil = self.transform(pil)
            yield pil, label

In [7]:
#4 Torch transforms (same as ResNet expects)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],
                         [0.229,0.224,0.225]),
])

In [8]:
#5 Instantiate DataLoader
dataset = ImageNetV2TopImages(tfds_ds, transform=preprocess)
loader  = DataLoader(dataset, batch_size=1, num_workers=4)



In [9]:
#Define models and cascade logic
import time
import torch.nn.functional as F
from torchvision import models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load pretrained ResNets
resnet18  = models.resnet18(pretrained=True).to(device).eval()    #model A
resnet34  = models.resnet34(pretrained=True).to(device).eval()    #model B
resnet50  = models.resnet50(pretrained=True).to(device).eval()    #model C
resnet152 = models.resnet152(pretrained=True).to(device).eval()   #model D

# Thresholds
THR_A = 0.95 # IDK threshold for model A
THR_B = 0.95 # IDK threshold for model B
THR_C = 0.95 # IDK threshold for model C

def predict(model, x):
    logits = model(x)
    probs = F.softmax(logits, dim=1)
    conf, cls = torch.max(probs, dim=1)
    return cls.item(), conf.item()

def get_prediction(model, x, thr=None):
    cls, conf = predict(model, x)
    if thr is not None and conf < thr:
        return "IDK", conf
    return cls, conf

def nonskip_abcd_cascade(x):
    # A: ResNet18
    cls_a, conf_a = get_prediction(resnet18, x, THR_A)
    if cls_a != "IDK":
        return cls_a, conf_a, "ResNet18"

    # B: ResNet34 (run if A returned IDK)
    cls_b, conf_b = get_prediction(resnet34, x, THR_B)
    if cls_b != "IDK":
        return cls_b, conf_b, "ResNet34"

    # C: ResNet50 (run if B returned IDK)
    cls_c, conf_c = get_prediction(resnet50, x, THR_C)
    if cls_c != "IDK":
        return cls_c, conf_c, "ResNet50"

    # D: ResNet152 (fallback, no threshold)
    cls_d, conf_d = predict(resnet152, x)
    return cls_d, conf_d, "ResNet152"

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 126MB/s]
Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth
100%|██████████| 83.3M/83.3M [00:00<00:00, 111MB/s]
Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:01<00:00, 55.3MB/s]
Downloading: "https://download.pytorch.org/models/resnet152-394f9c45.pth" to /root/.cache/torch/hub/checkpoints/resnet152-394f9c45.pth
100%|██████████| 230M/230M [00:01<00:00, 172MB/s]


In [10]:
from collections import Counter
import time

# Prepare counters
branch_sum = Counter()
branch_cnt = Counter()

# RUN & EVALUATE
MAX_IMAGES = 10000
total      = 0
correct    = 0
sum_time   = 0.0

for imgs, labels in loader:
    total += 1
    imgs, labels = imgs.to(device), labels.to(device)

    start = time.time()
    pred, conf, branch = nonskip_abcd_cascade(imgs)
    elapsed = time.time() - start

    # Update overall stats
    correct += (pred == labels.item())
    sum_time += elapsed

    # Update branch‑specific stats
    branch_sum[branch] += elapsed
    branch_cnt[branch] += 1

    # Light logging
    if total % 500 == 0:
        print(f"[{total:5d}] Pred={pred:4d}  Used={branch:<20s}  Time={elapsed*1000:.1f}ms")

    if total >= MAX_IMAGES:
        break

# Core metrics
accuracy  = correct / total
avg_time  = sum_time / total

# Print summary
print("\n=== SUMMARY ===")
print(f"Total images        : {total}")
print(f"Accuracy            : {accuracy*100:.2f}%")
print(f"Avg. time per input : {avg_time*1000:.1f} ms")

# Per‑branch breakdown
for branch, cnt in branch_cnt.items():
    t = branch_sum[branch]
    print(f"  {branch:>20s} | Count: {cnt:5d} | Avg time: {t/cnt*1000:.1f} ms")

[  500] Pred= 230  Used=ResNet152             Time=42.5ms
[ 1000] Pred= 225  Used=ResNet34              Time=12.9ms
[ 1500] Pred=  35  Used=ResNet50              Time=24.5ms
[ 2000] Pred= 244  Used=ResNet18              Time=4.5ms
[ 2500] Pred= 763  Used=ResNet152             Time=53.6ms
[ 3000] Pred= 420  Used=ResNet18              Time=4.6ms
[ 3500] Pred= 404  Used=ResNet152             Time=41.4ms
[ 4000] Pred= 700  Used=ResNet152             Time=44.9ms
[ 4500] Pred=  17  Used=ResNet18              Time=4.4ms
[ 5000] Pred= 200  Used=ResNet152             Time=56.0ms
[ 5500] Pred= 429  Used=ResNet152             Time=48.3ms
[ 6000] Pred= 457  Used=ResNet18              Time=6.9ms
[ 6500] Pred= 939  Used=ResNet152             Time=43.0ms
[ 7000] Pred= 710  Used=ResNet50              Time=20.0ms
[ 7500] Pred= 977  Used=ResNet152             Time=43.3ms
[ 8000] Pred= 914  Used=ResNet152             Time=42.5ms
[ 8500] Pred= 760  Used=ResNet50              Time=21.0ms
[ 9000] Pred= 685 