In [None]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset
import torch.nn.functional as F

class NPZFrameDataset(Dataset):
    def __init__(self, npz_dir, files, transform=None, binary=True):
        """
        Args:
            npz_dir (str): path to .npz files
            files (list[str]): list of filenames
            transform (callable): torchvision-style transform (on torch.Tensor CxHxW)
            binary (bool): if True → merge class 2 into 1 (binary classification)
        """
        self.samples = []
        self.transform = transform
        self.binary = binary

        for f in files:
            case = np.load(os.path.join(npz_dir, f))
            images = case["image"].astype(np.float32)   # (F,H,W)
            labels = case["label"].astype(np.int64)     # (F,)

            # normalize to [0,1] at load time
            images = images / 255.0

            if self.binary:
                labels[labels == 2] = 1   # merge suboptimal with optimal

            # flatten into (frame, label) list
            for img, lbl in zip(images, labels):
                self.samples.append((img, lbl))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img, lbl = self.samples[idx]

        # (H,W) → (1,H,W)
        img = torch.from_numpy(img).unsqueeze(0)  # torch.float32

        # resize to (1,224,224)
        img = F.interpolate(
            img.unsqueeze(0),
            size=(224, 224),
            mode="bilinear",
            align_corners=False
        ).squeeze(0)

        if self.transform is not None:
            img = self.transform(img)

        return img, torch.tensor(lbl).long()


In [2]:
import torchvision.models as models
import torch.nn as nn

class FrameClassifier(nn.Module):
    def __init__(self, num_classes=2):
        super().__init__()
        self.backbone = models.resnet50(pretrained=True)
        # change first conv → grayscale
        self.backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.backbone.fc = nn.Linear(self.backbone.fc.in_features, num_classes)

    def forward(self, x):   # (B,1,H,W)
        return self.backbone(x)


In [3]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms as T
from tqdm import tqdm


# ==== Dataset with Augmentation ====

train_transform = T.Compose([
    # --- Geometric ---
    T.RandomHorizontalFlip(p=0.5),
    T.RandomVerticalFlip(p=0.2),
    T.RandomRotation(degrees=15),
    T.RandomResizedCrop(224, scale=(0.8, 1.0), ratio=(0.9, 1.1)),

    # --- Photometric ---
    T.ColorJitter(brightness=0.2, contrast=0.2),
    T.RandomApply([T.GaussianBlur(kernel_size=3)], p=0.3),
    T.RandomApply([T.GaussianBlur(kernel_size=5)], p=0.1),
    T.RandomAdjustSharpness(sharpness_factor=2, p=0.3),

    # --- Occlusion / noise (tensor ops) ---
    T.RandomErasing(p=0.25, scale=(0.02, 0.1), ratio=(0.3, 3.3), value=0),
    T.RandomApply([T.Lambda(lambda x: (x + 0.05*torch.randn_like(x)).clamp(0, 1))], p=0.2),

    # --- Final (no ToTensor here) ---
    T.Normalize(mean=[0.5], std=[0.5])  # for 1-channel tensors in [0,1]
])

val_transform = T.Compose([
    T.Normalize(mean=[0.5], std=[0.5])
])


train_dataset = NPZFrameDataset("D:/dataset/npz_80_tiny",  
                                sorted(os.listdir("D:/dataset/npz_80_tiny"))[:210],  
                                transform=train_transform)
val_dataset = NPZFrameDataset("D:/dataset/converted_classifier_npz_compact", 
                            sorted(os.listdir("D:/dataset/converted_classifier_npz_compact"))[210:255], 
                            transform=val_transform)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0)
val_loader   = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=0)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = FrameClassifier(num_classes=2).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)  # weight decay added
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=2, factor=0.5)



## Training LOOP (DONT RUN)

In [8]:
import torch
import torch.nn as nn
from tqdm import tqdm
from sklearn.metrics import classification_report

# ---- Config ----
best_acc = 0.0
patience = 5
no_improve = 0
num_epochs = 30
save_path = "D:/acouslic-ai-cse4622/saved_weights/best_frame_classifier_tiny.pth"

scaler = torch.cuda.amp.GradScaler() if device == "cuda" else None

for epoch in range(num_epochs):
    # ----------------------
    # Training
    # ----------------------
    model.train()
    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}")
    total_loss = 0.0

    for images, labels in loop:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        if scaler:
            with torch.cuda.amp.autocast():
                outputs = model(images)
                loss = criterion(outputs, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        total_loss += loss.item()
        loop.set_postfix(loss=loss.item())

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}: Avg Train Loss = {avg_loss:.4f}")

    # ----------------------
    # Validation
    # ----------------------
    model.eval()
    val_loss, correct, total = 0.0, 0, 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            preds = outputs.argmax(1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    val_loss /= len(val_loader)
    val_acc = 100 * correct / total

    print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
    print(classification_report(all_labels, all_preds, digits=4))

    # ----------------------
    # Scheduler + Early Stopping
    # ----------------------
    scheduler.step(val_acc)

    if val_acc > best_acc:
        best_acc = val_acc
        no_improve = 0
        torch.save({
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "best_acc": best_acc
        }, save_path)
        print(f"✅ Best model saved with val acc = {best_acc:.2f}%")
    else:
        no_improve += 1
        if no_improve >= patience:
            print("⏹️ Early stopping triggered.")
            break

print(f"Training finished. Best validation accuracy = {best_acc:.2f}%")


  scaler = torch.cuda.amp.GradScaler() if device == "cuda" else None
  with torch.cuda.amp.autocast():
Epoch 1: 100%|██████████| 4200/4200 [02:34<00:00, 27.10it/s, loss=0.263] 


Epoch 1: Avg Train Loss = 0.3681
Val Loss: 0.1217 | Val Acc: 95.77%
              precision    recall  f1-score   support

           0     0.9905    0.9659    0.9780     36852
           1     0.3253    0.6382    0.4309       948

    accuracy                         0.9577     37800
   macro avg     0.6579    0.8021    0.7045     37800
weighted avg     0.9738    0.9577    0.9643     37800

✅ Best model saved with val acc = 95.77%


  with torch.cuda.amp.autocast():
Epoch 2: 100%|██████████| 4200/4200 [02:38<00:00, 26.43it/s, loss=0.263]  


Epoch 2: Avg Train Loss = 0.2713
Val Loss: 0.2322 | Val Acc: 90.43%
              precision    recall  f1-score   support

           0     0.9979    0.9037    0.9485     36852
           1     0.1983    0.9262    0.3267       948

    accuracy                         0.9043     37800
   macro avg     0.5981    0.9149    0.6376     37800
weighted avg     0.9778    0.9043    0.9329     37800



  with torch.cuda.amp.autocast():
Epoch 3: 100%|██████████| 4200/4200 [02:37<00:00, 26.64it/s, loss=0.0261] 


Epoch 3: Avg Train Loss = 0.2402
Val Loss: 0.1945 | Val Acc: 92.91%
              precision    recall  f1-score   support

           0     0.9971    0.9300    0.9624     36852
           1     0.2474    0.8945    0.3877       948

    accuracy                         0.9291     37800
   macro avg     0.6223    0.9123    0.6750     37800
weighted avg     0.9783    0.9291    0.9480     37800



  with torch.cuda.amp.autocast():
Epoch 4: 100%|██████████| 4200/4200 [02:37<00:00, 26.63it/s, loss=0.0933] 


Epoch 4: Avg Train Loss = 0.2256
Val Loss: 0.0912 | Val Acc: 96.69%
              precision    recall  f1-score   support

           0     0.9927    0.9732    0.9828     36852
           1     0.4090    0.7205    0.5218       948

    accuracy                         0.9669     37800
   macro avg     0.7008    0.8468    0.7523     37800
weighted avg     0.9780    0.9669    0.9713     37800

✅ Best model saved with val acc = 96.69%


  with torch.cuda.amp.autocast():
Epoch 5: 100%|██████████| 4200/4200 [02:38<00:00, 26.49it/s, loss=0.196]  


Epoch 5: Avg Train Loss = 0.2135
Val Loss: 0.0935 | Val Acc: 96.33%
              precision    recall  f1-score   support

           0     0.9945    0.9677    0.9809     36852
           1     0.3866    0.7911    0.5194       948

    accuracy                         0.9633     37800
   macro avg     0.6905    0.8794    0.7502     37800
weighted avg     0.9792    0.9633    0.9693     37800



  with torch.cuda.amp.autocast():
Epoch 6: 100%|██████████| 4200/4200 [02:36<00:00, 26.79it/s, loss=0.527]  


Epoch 6: Avg Train Loss = 0.1992
Val Loss: 0.2084 | Val Acc: 91.03%
              precision    recall  f1-score   support

           0     0.9987    0.9092    0.9518     36852
           1     0.2127    0.9536    0.3478       948

    accuracy                         0.9103     37800
   macro avg     0.6057    0.9314    0.6498     37800
weighted avg     0.9790    0.9103    0.9367     37800



  with torch.cuda.amp.autocast():
Epoch 7: 100%|██████████| 4200/4200 [02:38<00:00, 26.43it/s, loss=0.00913]


Epoch 7: Avg Train Loss = 0.1942
Val Loss: 0.1212 | Val Acc: 95.11%
              precision    recall  f1-score   support

           0     0.9963    0.9534    0.9744     36852
           1     0.3223    0.8618    0.4691       948

    accuracy                         0.9511     37800
   macro avg     0.6593    0.9076    0.7217     37800
weighted avg     0.9794    0.9511    0.9617     37800



  with torch.cuda.amp.autocast():
Epoch 8: 100%|██████████| 4200/4200 [02:38<00:00, 26.52it/s, loss=0.0106] 


Epoch 8: Avg Train Loss = 0.1602
Val Loss: 0.0574 | Val Acc: 97.63%
              precision    recall  f1-score   support

           0     0.9933    0.9824    0.9878     36852
           1     0.5196    0.7416    0.6110       948

    accuracy                         0.9763     37800
   macro avg     0.7564    0.8620    0.7994     37800
weighted avg     0.9814    0.9763    0.9783     37800

✅ Best model saved with val acc = 97.63%


  with torch.cuda.amp.autocast():
Epoch 9: 100%|██████████| 4200/4200 [02:35<00:00, 27.08it/s, loss=0.00912]


Epoch 9: Avg Train Loss = 0.1467
Val Loss: 0.1353 | Val Acc: 95.28%
              precision    recall  f1-score   support

           0     0.9980    0.9534    0.9752     36852
           1     0.3385    0.9262    0.4958       948

    accuracy                         0.9528     37800
   macro avg     0.6682    0.9398    0.7355     37800
weighted avg     0.9815    0.9528    0.9632     37800



  with torch.cuda.amp.autocast():
Epoch 10: 100%|██████████| 4200/4200 [02:38<00:00, 26.58it/s, loss=0.00969] 


Epoch 10: Avg Train Loss = 0.1399
Val Loss: 0.0820 | Val Acc: 96.72%
              precision    recall  f1-score   support

           0     0.9952    0.9710    0.9830     36852
           1     0.4211    0.8186    0.5561       948

    accuracy                         0.9672     37800
   macro avg     0.7081    0.8948    0.7695     37800
weighted avg     0.9808    0.9672    0.9723     37800



  with torch.cuda.amp.autocast():
Epoch 11: 100%|██████████| 4200/4200 [02:38<00:00, 26.55it/s, loss=0.0158]  


Epoch 11: Avg Train Loss = 0.1343
Val Loss: 0.0876 | Val Acc: 96.59%
              precision    recall  f1-score   support

           0     0.9969    0.9680    0.9823     36852
           1     0.4154    0.8829    0.5650       948

    accuracy                         0.9659     37800
   macro avg     0.7061    0.9255    0.7736     37800
weighted avg     0.9823    0.9659    0.9718     37800



  with torch.cuda.amp.autocast():
Epoch 12: 100%|██████████| 4200/4200 [02:38<00:00, 26.58it/s, loss=0.0224]  


Epoch 12: Avg Train Loss = 0.1121
Val Loss: 0.0689 | Val Acc: 97.13%
              precision    recall  f1-score   support

           0     0.9959    0.9746    0.9851     36852
           1     0.4609    0.8449    0.5964       948

    accuracy                         0.9713     37800
   macro avg     0.7284    0.9098    0.7908     37800
weighted avg     0.9825    0.9713    0.9754     37800



  with torch.cuda.amp.autocast():
Epoch 13: 100%|██████████| 4200/4200 [02:38<00:00, 26.44it/s, loss=0.0171]  


Epoch 13: Avg Train Loss = 0.1080
Val Loss: 0.0910 | Val Acc: 96.50%
              precision    recall  f1-score   support

           0     0.9969    0.9671    0.9818     36852
           1     0.4084    0.8840    0.5587       948

    accuracy                         0.9650     37800
   macro avg     0.7027    0.9255    0.7702     37800
weighted avg     0.9822    0.9650    0.9712     37800

⏹️ Early stopping triggered.
Training finished. Best validation accuracy = 97.63%


In [4]:
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# reload model with same architecture
model = FrameClassifier(num_classes=2).to(device)
checkpoint = torch.load("D:/acouslic-ai-cse4622/saved_weights/best_frame_classifier_tiny.pth", map_location=device)
model.load_state_dict(checkpoint["model_state"])
model.eval()

FrameClassifier(
  (backbone): ResNet(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
 

In [13]:
import numpy as np
import torch.nn.functional as F

def predict_case(npz_path, model, device="cpu", batch_size=16):
    data = np.load(npz_path, allow_pickle=True)
    images = data["image"].astype(np.float32)  # (T,H,W)
    labels = data["label"].astype(np.int64)    # ground-truth (optional)

    # normalize [0,1]
    images = (images - images.min()) / (images.max() - images.min() + 1e-8)

    preds_all = []
    with torch.no_grad():
        for i in range(0, len(images), batch_size):
            batch = images[i:i+batch_size]
            batch = torch.tensor(batch).unsqueeze(1).to(device)  # (B,1,H,W)

            # resize to 224x224
            batch = F.interpolate(batch, size=(224,224),
                                  mode="bilinear", align_corners=False)

            outputs = model(batch)
            preds = outputs.argmax(1).cpu().numpy()
            preds_all.extend(preds)

    return np.array(preds_all), labels


In [11]:
import os
from sklearn.metrics import classification_report

TEST_DIR = "D:/dataset/converted_classifier_npz_compact"  # full 840-frame cases
test_files = sorted(os.listdir(TEST_DIR))[-45:]  # example: last 45 for test

all_preds, all_labels = [], []
for f in test_files:
    if not f.endswith(".npz"): continue
    preds, labels = predict_case(os.path.join(TEST_DIR,f), model, device=device)
    all_preds.extend(preds)
    all_labels.extend(labels)

print("Test Classification Report:")
print(classification_report(all_labels, all_preds, digits=4))


Test Classification Report:
              precision    recall  f1-score   support

           0     0.9740    1.0000    0.9868     36818
           1     0.0000    0.0000    0.0000       240
           2     0.0000    0.0000    0.0000       742

    accuracy                         0.9740     37800
   macro avg     0.3247    0.3333    0.3289     37800
weighted avg     0.9487    0.9740    0.9612     37800



  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


In [14]:
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix

# ---- Load best checkpoint ----
checkpoint = torch.load("D:/acouslic-ai-cse4622/saved_weights/best_frame_classifier_tiny.pth", map_location=device)
model.load_state_dict(checkpoint["model_state"])
model.eval()

# ---- Collect predictions ----
all_preds, all_labels = [], []

with torch.no_grad():
    for images, labels in tqdm(val_loader, desc="Inference"):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        preds = outputs.argmax(1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# ---- Metrics ----
print("Classification Report (Frame-level):")
print(classification_report(all_labels, all_preds, target_names=["Background","Positive"], digits=4))

print("Confusion Matrix:")
print(confusion_matrix(all_labels, all_preds))


Inference: 100%|██████████| 9450/9450 [01:30<00:00, 104.06it/s]


Classification Report (Frame-level):
              precision    recall  f1-score   support

  Background     0.9933    0.9824    0.9878     36852
    Positive     0.5196    0.7416    0.6110       948

    accuracy                         0.9763     37800
   macro avg     0.7564    0.8620    0.7994     37800
weighted avg     0.9814    0.9763    0.9783     37800

Confusion Matrix:
[[36202   650]
 [  245   703]]


In [13]:
import torch.nn.functional as F

def pick_best_frame(npz_path, batch_size=64):
    case = np.load(npz_path, mmap_mode="r")
    images = case["image"]   # (F,H,W)
    labels = case["label"]   # (F,)

    best_idx, best_score = None, -1.0

    for start in range(0, len(images), batch_size):
        batch = images[start:start+batch_size].astype(np.float32)
        b = torch.from_numpy(batch).unsqueeze(1)  # (B,1,H,W)

        # normalize to [0,1]
        b_min = b.amin(dim=(2,3), keepdim=True)
        b_max = b.amax(dim=(2,3), keepdim=True)
        b = (b - b_min) / (b_max - b_min + 1e-8)

        # resize to match training (224,224)
        b = F.interpolate(b, size=(224,224), mode="bilinear", align_corners=False)
        b = (b - 0.5) / 0.5   # normalization
        b = b.to(device)

        with torch.no_grad():
            logits = model(b)
            probs = torch.softmax(logits, dim=1)[:,1]  # positive prob

        max_prob, max_idx = torch.max(probs, dim=0)
        if max_prob.item() > best_score:
            best_score = max_prob.item()
            best_idx = start + max_idx.item()

    # GT positives
    positives = np.where(labels > 0)[0]
    return best_idx, best_score, positives


In [8]:
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np

def evaluate_case_level(val_files, root_dir, model, device, batch_size=64):
    model.eval()
    total_cases, case_hits = 0, 0

    for fname in val_files:
        case = np.load(os.path.join(root_dir, fname))
        images = case["image"].astype(np.float32)   # (F,H,W)
        labels = case["label"].astype(np.int64)     # (F,)
        labels[labels == 2] = 1   # remap

        positives = np.where(labels == 1)[0]

        best_idx, best_score = None, -1.0
        for start in range(0, len(images), batch_size):
            batch = images[start:start+batch_size]
            b = torch.from_numpy(batch).unsqueeze(1)  # (B,1,H,W)
            b = (b - b.min()) / (b.max() - b.min() + 1e-8)  # normalize per-batch
            b = F.interpolate(b, size=(224,224), mode="bilinear", align_corners=False)
            b = (b - 0.5) / 0.5  # [-1,1]
            b = b.to(device)

            with torch.no_grad():
                logits = model(b)
                probs = torch.softmax(logits, dim=1)[:,1]  # probability of positive

            max_prob, max_idx = torch.max(probs, dim=0)
            if max_prob.item() > best_score:
                best_score = max_prob.item()
                best_idx = start + max_idx.item()

        # case success? (did we catch a positive if exists)
        if len(positives) > 0:
            success = best_idx in positives
        else:
            success = (labels[best_idx] == 0)

        total_cases += 1
        case_hits += int(success)

    case_acc = 100 * case_hits / total_cases
    print(f"✅ Case-level Hit Rate: {case_hits}/{total_cases} ({case_acc:.2f}%)")
    return case_acc


# === During Validation Loop ===
model.eval()
all_preds, all_labels = [], []
val_loss, correct, total = 0.0, 0, 0

with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        val_loss += loss.item()

        preds = outputs.argmax(1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

val_loss /= len(val_loader)
val_acc = 100 * correct / total

print("\n--- Frame-level Report ---")
print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
print(classification_report(all_labels, all_preds, digits=4))
print("Confusion Matrix:")
print(confusion_matrix(all_labels, all_preds))

print("\n--- Case-level Report ---")
val_files = sorted(os.listdir("D:/dataset/converted_classifier_npz_compact"))[210:255]
evaluate_case_level(val_files, "D:/dataset/converted_classifier_npz_compact", model, device)


NameError: name 'val_loader' is not defined

In [5]:
import os, random, numpy as np, torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
random.seed(1337)

# ---- Load model from checkpoint (dict) ----
model = FrameClassifier(num_classes=2).to(device)
ckpt = torch.load("D:/acouslic-ai-cse4622/saved_weights/best_frame_classifier.pth",
                  map_location=device)
state = ckpt["model_state"] if isinstance(ckpt, dict) and "model_state" in ckpt else ckpt
model.load_state_dict(state)
model.eval()

@torch.no_grad()
def _preprocess_batch_uint8_to_model(batch_uint8: np.ndarray, to224=True):
    """
    batch_uint8: (B, H, W), dtype uint8 or float32 0..255
    returns: (B,1,224,224) float32 normalized like training (mean=0.5,std=0.5)
    """
    b = torch.from_numpy(batch_uint8).float()   # (B,H,W)
    b = b.unsqueeze(1)                          # (B,1,H,W)
    b = b / 255.0                               # [0,1]
    if to224:
        b = F.interpolate(b, size=(224,224), mode="bilinear", align_corners=False)
    # Normalize to match training: mean=0.5, std=0.5 -> [-1,1]
    b = (b - 0.5) / 0.5
    return b.to(device)

@torch.no_grad()
def pick_best_frame(npz_path, batch_size=64):
    case = np.load(npz_path, mmap_mode="r")
    images = case["image"]  # (F,H,W), uint8
    best_idx, best_score = None, -1.0

    for start in range(0, len(images), batch_size):
        batch = images[start:start+batch_size]          # uint8
        b = _preprocess_batch_uint8_to_model(batch)     # (B,1,224,224)
        logits = model(b)
        probs = torch.softmax(logits, dim=1)[:, 1]      # P(positive)
        max_prob, max_idx = torch.max(probs, dim=0)
        if max_prob.item() > best_score:
            best_score = max_prob.item()
            best_idx = start + max_idx.item()
    return best_idx, best_score

def get_label_array(case_npz):
    for key in ["label", "labels", "y", "gt", "target"]:
        if key in case_npz:
            arr = np.asarray(case_npz[key]).reshape(-1)
            return arr
    raise KeyError("No per-frame label key found (label/labels/y/gt/target).")

def evaluate_random_subset(root_dir, start_idx=0, k=12, tol=0):
    """
    tol = tolerance window (in frames). tol=0 requires exact positive frame.
    """
    files = sorted([f for f in os.listdir(root_dir) if f.endswith(".npz")])
    test_files = files[start_idx:]
    if not test_files:
        raise RuntimeError(f"No .npz files at/after index {start_idx} under {root_dir}")
    sample = random.sample(test_files, k=min(k, len(test_files)))

    total = 0
    hits_anypos = 0
    exact_argmax_hits = 0
    mean_prob_accum = []
    dist_to_nearest_pos = []

    print(f"Evaluating {len(sample)} files (tol={tol})...\n")

    for fname in sample:
        path = os.path.join(root_dir, fname)
        case = np.load(path, mmap_mode="r")
        labels = get_label_array(case)              # 0/1/2
        positives = np.where((labels == 1) | (labels == 2))[0]

        pred_idx, pred_prob = pick_best_frame(path)
        total += 1
        mean_prob_accum.append(pred_prob)

        # Any positive within tolerance?
        if positives.size > 0:
            dists = np.abs(positives - pred_idx)
            hit_tol = np.any(dists <= tol)
            hits_anypos += int(hit_tol)
            # "argmax" reference (first positive)
            gt_argmax = positives[0]
            exact_argmax_hits += int(abs(pred_idx - gt_argmax) <= tol)
            dist = int(dists.min())
        else:
            # no positives in GT: success if prediction is background
            hit_tol = (labels[pred_idx] == 0)
            hits_anypos += int(hit_tol)
            gt_argmax = int(np.argmax(labels))
            exact_argmax_hits += int(pred_idx == gt_argmax)
            dist = 0

        dist_to_nearest_pos.append(dist)

        print(f"{fname}")
        print(f"  Frames: {len(labels)} | Pred idx: {pred_idx:4d} (p={pred_prob:.4f})")
        print(f"  GT positives: {positives.tolist()[:12]}{' ...' if len(positives)>12 else ''}")
        print(f"  Hit (±{tol})? {'YES' if hit_tol else 'NO'} | Dist→nearest +ve: {dist}\n")

    acc_anypos = hits_anypos / total
    acc_exact  = exact_argmax_hits / total
    mean_prob  = float(np.mean(mean_prob_accum))
    mean_dist  = float(np.mean(dist_to_nearest_pos))

    print("====== Summary ======")
    print(f"Files evaluated:      {total}")
    print(f"Hit any positive (±{tol}): {hits_anypos}/{total}  ({acc_anypos*100:.1f}%)")
    print(f"Matches to first +ve (±{tol}): {exact_argmax_hits}/{total}  ({acc_exact*100:.1f}%)")
    print(f"Mean predicted prob:  {mean_prob:.4f}")
    print(f"Mean dist to +ve:     {mean_dist:.2f} frames")
    print("=====================")


In [None]:
# Evaluate on full-length compacts (840 frames), e.g. last 30 files:
evaluate_random_subset(
    root_dir="D:/dataset/converted_classifier_npz_compact",
    start_idx=255,  # whatever split you used
    k=45,
    tol=0      # allow ±3 frames tolerance
)

# Or evaluate on your tiny 80-frame set (fast smoke test)

evaluate_random_subset(
    root_dir="D:/dataset/npz_80_tiny",
    start_idx=255,
    k=45,
    tol=0
)


Evaluating 45 files (tol=0)...

ea86047a-bae4-464b-a2ed-015935bebb2a.npz
  Frames: 840 | Pred idx:   60 (p=0.9500)
  GT positives: [50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61] ...
  Hit (±0)? YES | Dist→nearest +ve: 0

fc167d1b-045a-4057-936d-4862644af1f3.npz
  Frames: 840 | Pred idx:   92 (p=0.9765)
  GT positives: [80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91] ...
  Hit (±0)? YES | Dist→nearest +ve: 0

f8039e25-4652-440c-9476-b425f3fccb22.npz
  Frames: 840 | Pred idx:   56 (p=0.7737)
  GT positives: [52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] ...
  Hit (±0)? YES | Dist→nearest +ve: 0

d5c3cfee-53ac-4021-8c1b-098c189f630e.npz
  Frames: 840 | Pred idx:  626 (p=0.9527)
  GT positives: [20, 21, 22, 23, 24, 25, 164, 165, 166, 167, 168, 169] ...
  Hit (±0)? YES | Dist→nearest +ve: 0

d5471cfd-6090-4d42-9a95-67ccbfbf612e.npz
  Frames: 840 | Pred idx:   42 (p=0.9714)
  GT positives: [42, 43, 44, 45, 46, 47, 48, 176, 177, 178, 179, 180] ...
  Hit (±0)? YES | Dist→nearest +ve: 0

dc