In [20]:
import os, random, math, numpy as np
from collections import Counter
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Sampler

try:
    # torchvision v2 transforms (tensor-native)
    from torchvision.transforms import v2 as T
except:
    # fallback to classic, but v2 is recommended
    import torchvision.transforms as T
    
# --- put these at top-level (not inside a function/cell) ---
import torch
import torch.nn as nn

class AddGaussianNoise(nn.Module):
    def __init__(self, std=0.03):
        super().__init__()
        self.std = float(std)
    def forward(self, x):
        return x + torch.randn_like(x) * self.std if self.std > 0 else x

class NoOp(nn.Module):
    def forward(self, x):
        return x


from sklearn.metrics import classification_report, confusion_matrix
from tqdm import tqdm

# ----- Paths (edit these) -----
TRAIN_DIR = "D:/dataset/npz_80_tiny"                   # reduced, balanced-ish for training
VAL_DIR   = "D:/dataset/converted_classifier_npz_compact"  # full-length compacts for validation/inference
SAVE_PATH = "D:/acouslic-ai-cse4622/saved_weights/best_frame_classifier_BYOL.pth"

# ----- Splits (edit indices as you like) -----
train_files = sorted([f for f in os.listdir(TRAIN_DIR) if f.endswith(".npz")])[:210]
val_files   = sorted([f for f in os.listdir(VAL_DIR)   if f.endswith(".npz")])[210:255]

# ----- Training hyperparams -----
IMAGE_SIZE = 224
BATCH_SIZE = 32
VAL_BATCH  = 64
EPOCHS     = 20
PATIENCE   = 5
LR_HEAD    = 1e-3     # warmup (head-only)
LR_ALL     = 1e-4     # full fine-tune
WEIGHT_DEC = 1e-4
SEED       = 42

device = "cuda" if torch.cuda.is_available() else "cpu"
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)


<torch._C.Generator at 0x21761a987b0>

In [3]:
def _letterbox_to_square(x: torch.Tensor, size: int = 224) -> torch.Tensor:
    """
    x: (B,1,H,W) or (1,H,W) in [0,1] -> pad to square (keep aspect) -> resize to (..,1,size,size)
    """
    is_batched = (x.dim() == 4)
    if not is_batched:
        x = x.unsqueeze(0)  # (1,1,H,W)
    _, _, H, W = x.shape
    s = max(H, W)
    pad_h = (s - H) // 2
    pad_w = (s - W) // 2
    x = F.pad(x, (pad_w, s - W - pad_w, pad_h, s - H - pad_h))  # L,R,T,B
    x = F.interpolate(x, size=(size, size), mode="bilinear", align_corners=False)
    return x if is_batched else x.squeeze(0)

In [4]:
class PreloadedNPZFrameDataset(Dataset):
    def __init__(self,
                 npz_dir: str,
                 files: list[str],
                 binary: bool = True,
                 out_size: int = 224,
                 resize_mode: str = "letterbox",
                 dtype: torch.dtype = torch.float16):
        self.binary = binary
        self.out_size = out_size
        self.resize_mode = resize_mode
        self.dtype = dtype

        imgs_all, labels_all, pids_all = [], [], []   # <-- NEW

        for f in files:
            path = os.path.join(npz_dir, f)
            case = np.load(path, allow_pickle=True)
            imgs = case["image"]                    # (T,H,W) uint8
            y    = case["label"].astype(np.int64)   # (T,)
            if binary:
                y[y == 2] = 1

            # patient ID per frame: prefer uuid inside the npz; fallback to filename stem
            pid = str(case["uuid"]) if "uuid" in case else os.path.splitext(f)[0]
            tframes = imgs.shape[0]
            pids_all.append(torch.full((tframes,), hash(pid) & 0x7fffffff, dtype=torch.int64))  # stable int pid

            t = torch.from_numpy(imgs).unsqueeze(1).float() / 255.0  # (T,1,H,W) [0,1]
            if resize_mode == "letterbox":
                out_frames = [ _letterbox_to_square(fr, size=out_size) for fr in t ]
                t = torch.stack(out_frames, dim=0)  # (T,1,S,S)
            else:
                t = F.interpolate(t, size=(out_size, out_size), mode="bilinear", align_corners=False)

            t = (t - 0.5) / 0.5  # [-1,1]
            imgs_all.append(t.to(dtype=dtype).cpu())
            labels_all.append(torch.from_numpy(y).long())

        self.images = torch.cat(imgs_all, dim=0)            # (N,1,S,S)
        self.labels = torch.cat(labels_all, dim=0)          # (N,)
        self.patient_ids = torch.cat(pids_all, dim=0)       # (N,)  <-- NEW

    def __len__(self):
        return self.labels.numel()

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]


In [6]:
import math, random
from collections import defaultdict
from torch.utils.data import Sampler

class PatientAwareBalancedBatchSampler(Sampler):
    """
    50/50 pos/neg batches that also mix patients.
    - Uses every positive once per epoch (finite).
    - Negatives are sampled to match.
    - frames_per_patient_side controls per-patient contribution per side.

    labels:       list/1D tensor of {0,1}
    patient_ids:  list/1D tensor of ints (aligned with labels)
    batch_size:   even
    frames_per_patient_side: e.g., 2 -> up to 2 pos from a patient and up to 2 neg from a patient per batch
    """
    def __init__(self, labels, patient_ids, batch_size=32, frames_per_patient_side=2, pos_label=1, seed=42):
        assert batch_size % 2 == 0 and batch_size >= 2
        self.labels = list(map(int, labels))
        self.pids   = list(map(int, patient_ids))
        self.half   = batch_size // 2
        self.k      = max(1, int(frames_per_patient_side))
        self.pos_label = pos_label
        self.rng = random.Random(seed)

        # Build per-patient pools
        pos_by_pid = defaultdict(list)
        neg_by_pid = defaultdict(list)
        for i, (y, pid) in enumerate(zip(self.labels, self.pids)):
            if y == pos_label: pos_by_pid[pid].append(i)
            elif y == 0:       neg_by_pid[pid].append(i)

        if not pos_by_pid or not neg_by_pid:
            raise ValueError("Need patients with both pos and neg frames.")

        self.pos_by_pid = {pid: idxs for pid, idxs in pos_by_pid.items() if idxs}
        self.neg_by_pid = {pid: idxs for pid, idxs in neg_by_pid.items() if idxs}
        self.pos_patients = list(self.pos_by_pid.keys())
        self.neg_patients = list(self.neg_by_pid.keys())

        # Steps per epoch: exhaust positives once
        self.num_pos = sum(len(v) for v in self.pos_by_pid.values())
        self.steps   = math.ceil(self.num_pos / self.half)

    def __len__(self):
        return self.steps

    def _multi_patient_take(self, pools_by_pid, patient_ids, need):
        """
        Take 'need' indices by visiting many patients, up to self.k per patient.
        Refill per-patient shuffled pools when everyone is exhausted.
        """
        rng = self.rng
        out = []
        # Build shuffled per-patient pools
        active = {pid: rng.sample(idxs, k=len(idxs)) for pid, idxs in pools_by_pid.items()}
        ring = patient_ids[:]
        rng.shuffle(ring)
        head = 0

        while len(out) < need:
            if head >= len(ring):
                head = 0
                rng.shuffle(ring)
            pid = ring[head]; head += 1
            pool = active.get(pid, [])
            if not pool:
                continue
            take = min(self.k, len(pool), need - len(out))
            out.extend(pool[:take])
            active[pid] = pool[take:]

            # if everyone empty, refill
            if all(len(v) == 0 for v in active.values()):
                active = {pid: rng.sample(pools_by_pid[pid], k=len(pools_by_pid[pid])) for pid in pools_by_pid}

        return out

    def __iter__(self):
        rng = self.rng

        # Flat list of all positives, shuffled; ensure enough for steps*half
        all_pos = []
        for pid, idxs in self.pos_by_pid.items():
            all_pos.extend(idxs)
        rng.shuffle(all_pos)

        total_pos_needed = self.steps * self.half
        if len(all_pos) < total_pos_needed:
            reps = math.ceil(total_pos_needed / len(all_pos))
            all_pos = (all_pos * reps)[:total_pos_needed]
        else:
            all_pos = all_pos[:total_pos_needed]

        pos_cursor = 0

        for _ in range(self.steps):
            # Positives: take next chunk then redistribute to mix patients (cap k/patient)
            pos_slice = all_pos[pos_cursor: pos_cursor + self.half]
            pos_cursor += self.half

            # Bucket that slice by patient
            by_pid = defaultdict(list)
            for i in pos_slice:
                by_pid[self.pids[i]].append(i)

            # Take up to k per patient, top-up if needed
            p_batch = []
            pids_shuf = list(by_pid.keys())
            rng.shuffle(pids_shuf)
            for pid in pids_shuf:
                take = min(self.k, len(by_pid[pid]), self.half - len(p_batch))
                if take > 0:
                    p_batch.extend(by_pid[pid][:take])
                if len(p_batch) == self.half: break
            if len(p_batch) < self.half:
                # top up from leftovers in this slice
                extras = []
                for pid in pids_shuf:
                    extras.extend(by_pid[pid][self.k:])
                p_batch = (p_batch + extras)[:self.half]

            # Negatives: sample across many patients (cap k/patient)
            n_batch = self._multi_patient_take(self.neg_by_pid, self.neg_patients, self.half)

            batch = p_batch + n_batch
            rng.shuffle(batch)
            yield batch


In [7]:
import torchvision.models as models

class FrameClassifier(nn.Module):
    def __init__(self, num_classes=2):
        super().__init__()
        try:
            backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        except:
            backbone = models.resnet50(pretrained=True)

        # replace first conv (3→1 channels)
        old = backbone.conv1  # (64,3,7,7)
        new = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        with torch.no_grad():
            new.weight.copy_(old.weight.mean(dim=1, keepdim=True))  # average RGB to gray
        backbone.conv1 = new

        # replace classifier head
        backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)

        self.backbone = backbone

    def forward(self, x):
        return self.backbone(x)


In [8]:

criterion = nn.CrossEntropyLoss(label_smoothing=0.05).to(device)
# criterion = FocalLoss(alpha=(1.0, 1.3), gamma=2.0).to(device)


In [9]:
# SAFE, ULTRASOUND-FRIENDLY AUGS
try:
    from torchvision.transforms import v2 as T
    _HAS_V2 = True
except:
    import torchvision.transforms as T
    _HAS_V2 = False

train_tf = T.Compose([
    T.RandomHorizontalFlip(p=0.3),                 # common probe orientation change
    T.RandomVerticalFlip(p=0.1),                   # rarer; OK but small prob
    T.RandomRotation(degrees=15),                  # small tilt
    T.RandomAffine(degrees=0, translate=(0.08, 0.08), scale=(0.9, 1.1)),  # ±8% shift, ±10% zoom

    # intensity realism (no hue/sat for grayscale)
    T.RandomApply([T.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0))], p=0.2),
    T.RandomApply([T.Lambda(lambda x: x + 0.03 * torch.randn_like(x))], p=0.2),
    # NOTE: no Normalize here (your dataset already did (t-0.5)/0.5 → [-1,1])
])

# validation: true no-op
val_tf = T.Compose([T.Identity()]) if _HAS_V2 else T.Compose([T.Lambda(lambda x: x)])


In [10]:
from torch.utils.data import Dataset

class WithTransform(Dataset):
    def __init__(self, base: Dataset, transform=None):
        self.base = base
        self.transform = transform
    def __len__(self):
        return len(self.base)
    def __getitem__(self, idx):
        x, y = self.base[idx]        # x: (1,S,S), currently FP16 and already in [-1,1]
        x = x.to(dtype=torch.float32)
        if self.transform is not None:
            x = self.transform(x)
        return x, y


In [11]:
# base datasets (already normalize to [-1,1] inside PreloadedNPZFrameDataset)
train_ds_base = PreloadedNPZFrameDataset(TRAIN_DIR, train_files, binary=True, out_size=IMAGE_SIZE)
val_ds_base   = PreloadedNPZFrameDataset(VAL_DIR,   val_files,   binary=True, out_size=IMAGE_SIZE)


In [12]:
train_sampler = PatientAwareBalancedBatchSampler(
    labels=train_ds_base.labels.tolist(),
    patient_ids=train_ds_base.patient_ids.tolist(),
    batch_size=BATCH_SIZE,
    frames_per_patient_side=2,   # try 1–2; increase if you want more patients per batch
    pos_label=1,
    seed=SEED
)


In [13]:
train_ds = WithTransform(train_ds_base, transform=train_tf)   # your safe train_tf
val_ds   = WithTransform(val_ds_base,   transform=val_tf)     # no-op for val


In [14]:
from torch.utils.data import DataLoader

train_loader = DataLoader(
    train_ds,
    batch_sampler=train_sampler,    # NOTE: batch_sampler (do NOT also pass batch_size)
    num_workers=0,
    pin_memory=True
)
val_loader = DataLoader(
    val_ds,
    batch_size=VAL_BATCH,
    shuffle=False,
    num_workers=0,
    pin_memory=True
)


#### byol pretraining

In [15]:
# ===== BYOL pretrain for grayscale ResNet50 (Windows-safe, AMP, early stopping) =====
import os, math, numpy as np, torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader
from torch import amp

torch.backends.cudnn.benchmark = True

# ---------------- small, picklable helper transforms ----------------
class AddGaussianNoise(nn.Module):
    def __init__(self, std=0.03):
        super().__init__()
        self.std = float(std)
    def forward(self, x):
        if self.std <= 0: return x
        return x + torch.randn_like(x) * self.std

class TwoCropsTransform(nn.Module):
    """Return two independently augmented views."""
    def __init__(self, base):
        super().__init__()
        self.base = base
    def forward(self, x):
        v1 = self.base(x)
        v2 = self.base(x)
        return v1, v2

# ---------------- SSL-friendly augs (operate on [-1,1] tensors) ----------------
try:
    from torchvision.transforms import v2 as T
    _V2 = True
except:
    import torchvision.transforms as T
    _V2 = False

ssl_base_aug = T.Compose([
    T.RandomHorizontalFlip(p=0.3),
    T.RandomVerticalFlip(p=0.1),
    T.RandomRotation(degrees=15),
    T.RandomAffine(degrees=0, translate=(0.08,0.08), scale=(0.9,1.1)),
    T.RandomApply([T.GaussianBlur(kernel_size=3, sigma=(0.1,1.0))], p=0.2),
    T.RandomApply([AddGaussianNoise(0.03)], p=0.2),
])
ssl_transform = TwoCropsTransform(ssl_base_aug)

# ---------------- Dataset wrapper: returns (view1, view2) ----------------
class SSLPairDataset(Dataset):
    def __init__(self, base_ds: Dataset, transform):
        self.base = base_ds
        self.transform = transform
    def __len__(self): return len(self.base)
    def __getitem__(self, idx):
        x, _ = self.base[idx]          # x: (1,S,S) fp16 in [-1,1]
        x = x.to(torch.float32)        # augs in fp32
        v1, v2 = self.transform(x)
        return v1, v2

# ---------------- BYOL components ----------------
import torchvision.models as models
import copy

class Projector(nn.Module):
    # 2-layer MLP with BN (as common for BYOL/SimSiam)
    def __init__(self, in_dim, hidden=2048, out_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden, bias=False),
            nn.BatchNorm1d(hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, out_dim, bias=True)
        )
    def forward(self, x): return self.net(x)

class Predictor(nn.Module):
    def __init__(self, in_dim=256, hidden=256, out_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden, bias=False),
            nn.BatchNorm1d(hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, out_dim)
        )
    def forward(self, x): return self.net(x)

def build_resnet50_gray(imagenet_init=True):
    # ResNet50 with 1-channel conv1 and fc=Identity (encoder only)
    try:
        backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1 if imagenet_init else None)
    except:
        backbone = models.resnet50(pretrained=imagenet_init)
    old = backbone.conv1
    new = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    with torch.no_grad():
        if old.weight.shape[1] == 3:
            new.weight.copy_(old.weight.mean(dim=1, keepdim=True))
        else:
            nn.init.kaiming_normal_(new.weight, mode="fan_out", nonlinearity="relu")
    backbone.conv1 = new
    feat_dim = backbone.fc.in_features
    backbone.fc = nn.Identity()
    return backbone, feat_dim

class BYOL(nn.Module):
    """
    Online: encoder + projector + predictor
    Target: EMA(encoder + projector)  (no predictor)
    Loss: 0.5 * [ 2 - 2*cos(p1, t2) ] + 0.5 * [ 2 - 2*cos(p2, t1) ]
    """
    def __init__(self, imagenet_init=True, tau_base=0.996):
        super().__init__()
        enc_online, feat_dim = build_resnet50_gray(imagenet_init=imagenet_init)
        self.online_encoder = enc_online
        self.online_proj    = Projector(feat_dim, hidden=2048, out_dim=256)
        self.predictor      = Predictor(256, hidden=256, out_dim=256)

        # target = EMA copies
        self.target_encoder = copy.deepcopy(self.online_encoder)
        self.target_proj    = copy.deepcopy(self.online_proj)
        for p in self.target_encoder.parameters(): p.requires_grad = False
        for p in self.target_proj.parameters():    p.requires_grad = False

        self.tau_base = float(tau_base)  # base EMA momentum
        self.tau = self.tau_base         # will be scheduled

    @torch.no_grad()
    def ema_update(self):
        # target = tau * target + (1 - tau) * online
        for o, t in zip(self.online_encoder.parameters(), self.target_encoder.parameters()):
            t.data.mul_(self.tau).add_(o.data, alpha=(1.0 - self.tau))
        for o, t in zip(self.online_proj.parameters(), self.target_proj.parameters()):
            t.data.mul_(self.tau).add_(o.data, alpha=(1.0 - self.tau))

    def forward(self, v1, v2):
        # Online branch
        f1 = self.online_encoder(v1); z1 = self.online_proj(f1); p1 = self.predictor(z1)
        f2 = self.online_encoder(v2); z2 = self.online_proj(f2); p2 = self.predictor(z2)
        # Target branch (no grad)
        with torch.no_grad():
            t1 = self.target_proj(self.target_encoder(v1))
            t2 = self.target_proj(self.target_encoder(v2))
        # cosine loss (0 is best)
        def cos_loss(p, z):
            p = F.normalize(p, dim=1); z = F.normalize(z, dim=1)
            return 2.0 - 2.0 * (p * z).sum(dim=1).mean()
        loss = 0.5 * cos_loss(p1, t2) + 0.5 * cos_loss(p2, t1)
        return loss

# ---------------- Build unlabeled dataset ----------------
# Uses your preloaded train set: frames already letterboxed & normalized to [-1,1]
ssl_base = PreloadedNPZFrameDataset(TRAIN_DIR, train_files, binary=True, out_size=IMAGE_SIZE)
ssl_ds   = SSLPairDataset(ssl_base, transform=ssl_transform)

# ---------------- Hyperparams ----------------
DEVICE       = "cuda" if torch.cuda.is_available() else "cpu"
SSL_BATCH    = 128
SSL_EPOCHS   = 30
LR_BASE      = 0.05
WEIGHT_DEC   = 1e-4

# Early stopping / checkpoints
PATIENCE     = 7
MIN_DELTA    = 1e-4
SAVE_DIR     = "D:/acouslic-ai-cse4622/saved_weights"
ENC_BEST     = os.path.join(SAVE_DIR, "byol_encoder_best.pth")
FULL_BEST    = os.path.join(SAVE_DIR, "byol_full_best.pth")
os.makedirs(SAVE_DIR, exist_ok=True)

ssl_loader = DataLoader(
    ssl_ds, batch_size=SSL_BATCH, shuffle=True,
    num_workers=0, pin_memory=True, drop_last=True
)

byol = BYOL(imagenet_init=True, tau_base=0.996).to(DEVICE)
opt   = SGD(byol.parameters(), lr=LR_BASE, momentum=0.9, weight_decay=WEIGHT_DEC)
sched = CosineAnnealingLR(opt, T_max=SSL_EPOCHS)
scaler = amp.GradScaler(enabled=(DEVICE=="cuda"))

# Cosine schedule for EMA tau over total steps: tau -> closer to 1
total_steps = SSL_EPOCHS * len(ssl_loader)
cur_step = 0

best_loss = float("inf")
epochs_no_improve = 0
best_epoch = 0

for epoch in range(1, SSL_EPOCHS+1):
    byol.train()
    running = 0.0
    num_seen = 0

    for v1, v2 in tqdm(ssl_loader, desc=f"BYOL pretrain {epoch}/{SSL_EPOCHS}"):
        v1 = v1.to(DEVICE, non_blocking=True).float()
        v2 = v2.to(DEVICE, non_blocking=True).float()

        # Update EMA momentum (tau) with cosine schedule
        if total_steps > 0:
            # tau = 1 - (1 - tau_base) * (cos(pi * t/T) + 1)/2
            cos_term = (1 + math.cos(math.pi * cur_step / total_steps)) / 2.0
            byol.tau = 1.0 - (1.0 - byol.tau_base) * cos_term
        cur_step += 1

        opt.zero_grad(set_to_none=True)
        with amp.autocast(device_type="cuda", enabled=(DEVICE=="cuda")):
            loss = byol(v1, v2)

        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()

        # EMA update AFTER optimizer step
        with torch.no_grad():
            byol.ema_update()

        running += loss.item() * v1.size(0)
        num_seen += v1.size(0)

    sched.step()
    avg_loss = running / max(1, num_seen)
    print(f"Epoch {epoch}: BYOL loss {avg_loss:.4f} (lower is better; 0 is ideal)")

    # Early stopping on best (lower) loss
    if (best_loss - avg_loss) > MIN_DELTA:
        best_loss = avg_loss
        best_epoch = epoch
        epochs_no_improve = 0
        # Save best encoder and full checkpoint
        torch.save(byol.online_encoder.state_dict(), ENC_BEST)
        torch.save({
            "epoch": epoch,
            "model_state": byol.state_dict(),
            "optimizer_state": opt.state_dict(),
            "scheduler_state": sched.state_dict(),
            "scaler_state": scaler.state_dict(),
            "best_loss": best_loss,
        }, FULL_BEST)
        print(f"✅ Saved BEST (epoch {epoch}) | loss {best_loss:.4f}")
    else:
        epochs_no_improve += 1
        print(f"No improvement ({epochs_no_improve}/{PATIENCE})")
        if epochs_no_improve >= PATIENCE:
            print("⏹ Early stopping triggered.")
            break

print(f"Best epoch: {best_epoch} | Best BYOL loss: {best_loss:.4f}")
print(f"Best encoder saved at: {ENC_BEST}")


BYOL pretrain 1/30: 100%|██████████| 131/131 [02:14<00:00,  1.03s/it]


Epoch 1: BYOL loss 0.7304 (lower is better; 0 is ideal)
✅ Saved BEST (epoch 1) | loss 0.7304


BYOL pretrain 2/30: 100%|██████████| 131/131 [02:08<00:00,  1.02it/s]


Epoch 2: BYOL loss 0.3747 (lower is better; 0 is ideal)
✅ Saved BEST (epoch 2) | loss 0.3747


BYOL pretrain 3/30: 100%|██████████| 131/131 [02:08<00:00,  1.02it/s]


Epoch 3: BYOL loss 0.2114 (lower is better; 0 is ideal)
✅ Saved BEST (epoch 3) | loss 0.2114


BYOL pretrain 4/30: 100%|██████████| 131/131 [02:08<00:00,  1.02it/s]


Epoch 4: BYOL loss 0.1342 (lower is better; 0 is ideal)
✅ Saved BEST (epoch 4) | loss 0.1342


BYOL pretrain 5/30: 100%|██████████| 131/131 [02:08<00:00,  1.02it/s]


Epoch 5: BYOL loss 0.0943 (lower is better; 0 is ideal)
✅ Saved BEST (epoch 5) | loss 0.0943


BYOL pretrain 6/30: 100%|██████████| 131/131 [02:07<00:00,  1.03it/s]


Epoch 6: BYOL loss 0.0717 (lower is better; 0 is ideal)
✅ Saved BEST (epoch 6) | loss 0.0717


BYOL pretrain 7/30: 100%|██████████| 131/131 [02:08<00:00,  1.02it/s]


Epoch 7: BYOL loss 0.0574 (lower is better; 0 is ideal)
✅ Saved BEST (epoch 7) | loss 0.0574


BYOL pretrain 8/30: 100%|██████████| 131/131 [02:08<00:00,  1.02it/s]


Epoch 8: BYOL loss 0.0485 (lower is better; 0 is ideal)
✅ Saved BEST (epoch 8) | loss 0.0485


BYOL pretrain 9/30: 100%|██████████| 131/131 [02:07<00:00,  1.03it/s]


Epoch 9: BYOL loss 0.0417 (lower is better; 0 is ideal)
✅ Saved BEST (epoch 9) | loss 0.0417


BYOL pretrain 10/30: 100%|██████████| 131/131 [02:07<00:00,  1.03it/s]


Epoch 10: BYOL loss 0.0369 (lower is better; 0 is ideal)
✅ Saved BEST (epoch 10) | loss 0.0369


BYOL pretrain 11/30: 100%|██████████| 131/131 [02:08<00:00,  1.02it/s]


Epoch 11: BYOL loss 0.0338 (lower is better; 0 is ideal)
✅ Saved BEST (epoch 11) | loss 0.0338


BYOL pretrain 12/30: 100%|██████████| 131/131 [02:08<00:00,  1.02it/s]


Epoch 12: BYOL loss 0.0314 (lower is better; 0 is ideal)
✅ Saved BEST (epoch 12) | loss 0.0314


BYOL pretrain 13/30: 100%|██████████| 131/131 [02:07<00:00,  1.02it/s]


Epoch 13: BYOL loss 0.0291 (lower is better; 0 is ideal)
✅ Saved BEST (epoch 13) | loss 0.0291


BYOL pretrain 14/30: 100%|██████████| 131/131 [02:07<00:00,  1.03it/s]


Epoch 14: BYOL loss 0.0276 (lower is better; 0 is ideal)
✅ Saved BEST (epoch 14) | loss 0.0276


BYOL pretrain 15/30: 100%|██████████| 131/131 [02:08<00:00,  1.02it/s]


Epoch 15: BYOL loss 0.0266 (lower is better; 0 is ideal)
✅ Saved BEST (epoch 15) | loss 0.0266


BYOL pretrain 16/30: 100%|██████████| 131/131 [02:07<00:00,  1.02it/s]


Epoch 16: BYOL loss 0.0255 (lower is better; 0 is ideal)
✅ Saved BEST (epoch 16) | loss 0.0255


BYOL pretrain 17/30: 100%|██████████| 131/131 [02:07<00:00,  1.03it/s]


Epoch 17: BYOL loss 0.0241 (lower is better; 0 is ideal)
✅ Saved BEST (epoch 17) | loss 0.0241


BYOL pretrain 18/30: 100%|██████████| 131/131 [02:07<00:00,  1.03it/s]


Epoch 18: BYOL loss 0.0234 (lower is better; 0 is ideal)
✅ Saved BEST (epoch 18) | loss 0.0234


BYOL pretrain 19/30: 100%|██████████| 131/131 [02:06<00:00,  1.04it/s]


Epoch 19: BYOL loss 0.0228 (lower is better; 0 is ideal)
✅ Saved BEST (epoch 19) | loss 0.0228


BYOL pretrain 20/30: 100%|██████████| 131/131 [02:03<00:00,  1.06it/s]


Epoch 20: BYOL loss 0.0224 (lower is better; 0 is ideal)
✅ Saved BEST (epoch 20) | loss 0.0224


BYOL pretrain 21/30: 100%|██████████| 131/131 [02:05<00:00,  1.04it/s]


Epoch 21: BYOL loss 0.0217 (lower is better; 0 is ideal)
✅ Saved BEST (epoch 21) | loss 0.0217


BYOL pretrain 22/30: 100%|██████████| 131/131 [02:05<00:00,  1.04it/s]


Epoch 22: BYOL loss 0.0212 (lower is better; 0 is ideal)
✅ Saved BEST (epoch 22) | loss 0.0212


BYOL pretrain 23/30: 100%|██████████| 131/131 [02:05<00:00,  1.04it/s]


Epoch 23: BYOL loss 0.0209 (lower is better; 0 is ideal)
✅ Saved BEST (epoch 23) | loss 0.0209


BYOL pretrain 24/30: 100%|██████████| 131/131 [02:05<00:00,  1.04it/s]


Epoch 24: BYOL loss 0.0207 (lower is better; 0 is ideal)
✅ Saved BEST (epoch 24) | loss 0.0207


BYOL pretrain 25/30: 100%|██████████| 131/131 [02:05<00:00,  1.04it/s]


Epoch 25: BYOL loss 0.0205 (lower is better; 0 is ideal)
✅ Saved BEST (epoch 25) | loss 0.0205


BYOL pretrain 26/30: 100%|██████████| 131/131 [02:05<00:00,  1.04it/s]


Epoch 26: BYOL loss 0.0203 (lower is better; 0 is ideal)
✅ Saved BEST (epoch 26) | loss 0.0203


BYOL pretrain 27/30: 100%|██████████| 131/131 [02:05<00:00,  1.04it/s]


Epoch 27: BYOL loss 0.0203 (lower is better; 0 is ideal)
No improvement (1/7)


BYOL pretrain 28/30: 100%|██████████| 131/131 [02:06<00:00,  1.04it/s]


Epoch 28: BYOL loss 0.0200 (lower is better; 0 is ideal)
✅ Saved BEST (epoch 28) | loss 0.0200


BYOL pretrain 29/30: 100%|██████████| 131/131 [02:04<00:00,  1.06it/s]


Epoch 29: BYOL loss 0.0200 (lower is better; 0 is ideal)
No improvement (1/7)


BYOL pretrain 30/30: 100%|██████████| 131/131 [02:02<00:00,  1.07it/s]

Epoch 30: BYOL loss 0.0201 (lower is better; 0 is ideal)
No improvement (2/7)
Best epoch: 28 | Best BYOL loss: 0.0200
Best encoder saved at: D:/acouslic-ai-cse4622/saved_weights\byol_encoder_best.pth





In [16]:
model = FrameClassifier(num_classes=2).to(device)

# --- load SimSiam encoder (from your SSL pretrain) ---
SSL_ENC_PATH = "D:/acouslic-ai-cse4622/saved_weights/byol_encoder_best.pth"
enc_state = torch.load(SSL_ENC_PATH, map_location=device)
missing, unexpected = model.backbone.load_state_dict(enc_state, strict=False)
print("Loaded BYOL encoder. missing:", missing, "unexpected:", unexpected)
# ------------------------------------------------------


Loaded BYOL encoder. missing: ['fc.weight', 'fc.bias'] unexpected: []


In [17]:
# loss (keep it simple first)
criterion = nn.CrossEntropyLoss(label_smoothing=0.05).to(device)

# head warmup: freeze backbone, train only the final FC
for p in model.backbone.parameters():
    p.requires_grad = False
for p in model.backbone.fc.parameters():
    p.requires_grad = True

opt = torch.optim.AdamW(model.backbone.fc.parameters(), lr=LR_HEAD, weight_decay=WEIGHT_DEC)

# (optional: modern AMP context later in the loop)
from torch import amp
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))  # ok
# and use: with amp.autocast(device_type="cuda", enabled=(device=="cuda")):


  scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))  # ok


In [18]:
# load encoder-only checkpoint
enc_state = torch.load("D:/acouslic-ai-cse4622/saved_weights/byol_encoder_best.pth", map_location=device)

# apply to your FrameClassifier backbone
missing, unexpected = model.backbone.load_state_dict(enc_state, strict=False)
print("Loaded SimSiam encoder. missing:", missing, "unexpected:", unexpected)


Loaded SimSiam encoder. missing: ['fc.weight', 'fc.bias'] unexpected: []


In [19]:
import numpy as np
import torch.nn.functional as F
from sklearn.metrics import precision_recall_fscore_support, classification_report, confusion_matrix

@torch.no_grad()
def collect_logits_targets(model, loader):
    model.eval()
    all_logits, all_targets = [], []
    for x, y in loader:
        x = x.to(device, non_blocking=True).float()
        y = y.to(device, non_blocking=True)
        logits = model(x)
        all_logits.append(logits.cpu())
        all_targets.append(y.cpu())
    return torch.cat(all_logits), torch.cat(all_targets).numpy()

def find_best_threshold(logits, targets, goal="f1", target_precision=None):
    probs = F.softmax(logits, dim=1)[:, 1].numpy()
    y = np.array(targets)
    best_t, best_score = 0.5, -1.0
    for t in np.linspace(0.3, 0.9, 61):
        yhat = (probs >= t).astype(int)
        p, r, f1, _ = precision_recall_fscore_support(y, yhat, average="binary", zero_division=0)
        score = (r if (target_precision and p >= target_precision) else
                 f1 if goal == "f1" else p)
        if score > best_score:
            best_score, best_t = score, t
    return best_t


In [21]:
best_metric = -1.0
epochs_no_improve = 0

for epoch in range(1, EPOCHS + 1):
    # unfreeze backbone after warmup epoch 1
    if epoch == 2 and any(not p.requires_grad for p in model.backbone.parameters()):
        for p in model.backbone.parameters():
            p.requires_grad = True
        opt = torch.optim.AdamW(model.parameters(), lr=LR_ALL, weight_decay=WEIGHT_DEC)

    model.train()
    running_loss = 0.0

    for x, y in tqdm(train_loader, desc=f"Train {epoch}/{EPOCHS}"):
        x = x.to(device, non_blocking=True).float()
        y = y.to(device, non_blocking=True)

        opt.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=(device == "cuda")):
            logits = model(x)
            loss = criterion(logits, y)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()

        running_loss += loss.item() * x.size(0)

    train_loss = running_loss / len(train_ds)

    # validation: pick threshold this epoch
    val_logits, val_targets = collect_logits_targets(model, val_loader)
    t_star = find_best_threshold(val_logits, val_targets, goal="f1")  # or target_precision=0.80

    probs = F.softmax(val_logits, dim=1)[:, 1].numpy()
    pred  = (probs >= t_star).astype(int)

    p, r, f1, _ = precision_recall_fscore_support(val_targets, pred, average="binary", zero_division=0)
    acc = (pred == val_targets).mean()

    print(f"\nEpoch {epoch} | train loss {train_loss:.4f} | val acc {acc*100:.2f}% | P/R/F1 {p:.3f}/{r:.3f}/{f1:.3f} | thr* {t_star:.2f}")

    # Save the best by F1 (positive class)
    metric = f1
    if metric > best_metric:
        best_metric = metric
        epochs_no_improve = 0
        torch.save(model.state_dict(), SAVE_PATH)
        print(f"✅ Saved best model (val F1={best_metric:.3f} @ thr={t_star:.2f})")
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= PATIENCE:
            print("⏹ Early stopping.")
            break

print(f"Best F1: {best_metric:.3f} | Weights saved to: {SAVE_PATH}")


  with torch.cuda.amp.autocast(enabled=(device == "cuda")):
Train 1/20: 100%|██████████| 346/346 [00:23<00:00, 14.81it/s]



Epoch 1 | train loss 0.2385 | val acc 97.04% | P/R/F1 0.405/0.387/0.396 | thr* 0.89
✅ Saved best model (val F1=0.396 @ thr=0.89)


  with torch.cuda.amp.autocast(enabled=(device == "cuda")):
Train 2/20: 100%|██████████| 346/346 [00:37<00:00,  9.14it/s]



Epoch 2 | train loss 0.1732 | val acc 96.17% | P/R/F1 0.382/0.854/0.528 | thr* 0.90
✅ Saved best model (val F1=0.528 @ thr=0.90)


  with torch.cuda.amp.autocast(enabled=(device == "cuda")):
Train 3/20: 100%|██████████| 346/346 [00:36<00:00,  9.52it/s]



Epoch 3 | train loss 0.1415 | val acc 97.98% | P/R/F1 0.589/0.638/0.613 | thr* 0.87
✅ Saved best model (val F1=0.613 @ thr=0.87)


  with torch.cuda.amp.autocast(enabled=(device == "cuda")):
Train 4/20: 100%|██████████| 346/346 [00:36<00:00,  9.55it/s]



Epoch 4 | train loss 0.1296 | val acc 97.18% | P/R/F1 0.461/0.736/0.567 | thr* 0.90


  with torch.cuda.amp.autocast(enabled=(device == "cuda")):
Train 5/20: 100%|██████████| 346/346 [00:36<00:00,  9.53it/s]



Epoch 5 | train loss 0.1263 | val acc 97.97% | P/R/F1 0.586/0.640/0.612 | thr* 0.89


  with torch.cuda.amp.autocast(enabled=(device == "cuda")):
Train 6/20: 100%|██████████| 346/346 [00:36<00:00,  9.61it/s]



Epoch 6 | train loss 0.1227 | val acc 97.76% | P/R/F1 0.546/0.630/0.585 | thr* 0.90


  with torch.cuda.amp.autocast(enabled=(device == "cuda")):
Train 7/20: 100%|██████████| 346/346 [00:36<00:00,  9.52it/s]



Epoch 7 | train loss 0.1179 | val acc 95.73% | P/R/F1 0.358/0.887/0.510 | thr* 0.90


  with torch.cuda.amp.autocast(enabled=(device == "cuda")):
Train 8/20: 100%|██████████| 346/346 [00:36<00:00,  9.54it/s]



Epoch 8 | train loss 0.1155 | val acc 96.49% | P/R/F1 0.407/0.867/0.554 | thr* 0.90
⏹ Early stopping.
Best F1: 0.613 | Weights saved to: D:/acouslic-ai-cse4622/saved_weights/best_frame_classifier_BYOL.pth


In [26]:
import os, numpy as np, torch
import torch.nn.functional as F
from torch import amp
from tqdm import tqdm

# -----------------------------
# Config (edit these)
# -----------------------------
DEVICE      = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_PATH  = "D:/acouslic-ai-cse4622/saved_weights/best_frame_classifier_BYOL.pth"
TEST_DIR    = "D:/dataset/converted_classifier_npz_compact"   # full 840-frame npz files
START_IDX   = 255   # where your test split starts
NUM_CASES   = 45
IMAGE_SIZE  = 224
MA_WINDOW   = 9 # temporal smoothing (moving average)
TOPK        = 3     # use top-1 for "best frame"
THRESH      = 0.87  # or e.g. 0.6 to require minimum smoothed prob
TOLERANCE   = 3     # 0 = exact frame; try 3 for ±3 frames

# -----------------------------
# Utils (match your training preprocessing)
# -----------------------------
def letterbox_to_square_tensor(x: torch.Tensor, size=IMAGE_SIZE) -> torch.Tensor:
    """x: (B,1,H,W) or (1,H,W) in [0,1] -> square pad -> resize to (..,1,size,size)"""
    is_batched = (x.dim() == 4)
    if not is_batched: x = x.unsqueeze(0)
    _, _, H, W = x.shape
    s = max(H, W)
    pad_h = (s - H) // 2
    pad_w = (s - W) // 2
    x = F.pad(x, (pad_w, s-W-pad_w, pad_h, s-H-pad_h))
    x = F.interpolate(x, size=(size, size), mode="bilinear", align_corners=False)
    return x if is_batched else x.squeeze(0)

def moving_average(a: np.ndarray, w: int = 7):
    if w <= 1: return a
    k = np.ones(w, dtype=np.float32) / w
    return np.convolve(a, k, mode="same")

@torch.no_grad()
def predict_probs_over_frames(npz_path, model, batch_size=128):
    d = np.load(npz_path, mmap_mode="r")
    frames = d["image"]  # (T,H,W) uint8
    Tn = len(frames)
    probs = np.zeros(Tn, dtype=np.float32)
    off = 0
    while off < Tn:
        chunk = frames[off:off+batch_size]
        x = torch.from_numpy(chunk).unsqueeze(1).float() / 255.0      # (B,1,H,W) [0,1]
        x = letterbox_to_square_tensor(x, size=IMAGE_SIZE)            # (B,1,S,S)
        x = (x - 0.5) / 0.5                                           # [-1,1]
        x = x.to(DEVICE)

        with amp.autocast(device_type="cuda", enabled=(DEVICE=="cuda")):
            logits = model(x)
            p = torch.softmax(logits, dim=1)[:, 1].detach().cpu().numpy()
        probs[off:off+len(p)] = p
        off += len(p)
    return probs

@torch.no_grad()
def pick_best_frame(npz_path, model, ma_window=7, topk=1, thresh=None):
    """Returns best_idx (int), best_score (float), idx_topk (np.ndarray), raw_probs, smoothed_probs."""
    probs = predict_probs_over_frames(npz_path, model)
    sm = moving_average(probs, w=ma_window)
    idx_sorted = np.argsort(-sm)
    idx_topk = idx_sorted[:topk]
    if thresh is not None:
        idx_topk = np.array([i for i in idx_topk if sm[i] >= thresh], dtype=int)
        if len(idx_topk) == 0:
            idx_topk = np.array([int(np.argmax(sm))], dtype=int)
    best_idx = int(idx_topk[0])
    best_score = float(sm[best_idx])
    return best_idx, best_score, idx_topk, probs, sm

def load_binary_labels(npz_path):
    y = np.load(npz_path, mmap_mode="r")["label"].astype(np.int64)
    y[y == 2] = 1
    return y

# -----------------------------
# Load model
# -----------------------------
model = FrameClassifier(num_classes=2).to(DEVICE)
ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
state = ckpt.get("model_state", ckpt)  # support plain state_dict too
model.load_state_dict(state)
model.eval()

# -----------------------------
# Pick 45 test files and evaluate
# -----------------------------
files = sorted([f for f in os.listdir(TEST_DIR) if f.endswith(".npz")])
test_files = files[START_IDX:START_IDX+NUM_CASES]
assert len(test_files) > 0, "No test files selected."

correct = 0
details = []

print(f"Evaluating {len(test_files)} cases (tolerance=±{TOLERANCE})...\n")
for f in tqdm(test_files):
    path = os.path.join(TEST_DIR, f)
    y = load_binary_labels(path)
    pos_idx = np.where(y == 1)[0]  # ground truth positive frames

    best_idx, best_score, idx_topk, probs, sm = pick_best_frame(
        path, model, ma_window=MA_WINDOW, topk=TOPK, thresh=THRESH
    )

    if len(pos_idx) == 0:
        hit = (y[best_idx] == 0)  # no positives: "correct" if best is background
        dist = 0
    else:
        # exact or ±tolerance match
        dist = int(np.min(np.abs(pos_idx - best_idx)))
        hit = (dist <= TOLERANCE)

    correct += int(hit)
    details.append((f, best_idx, best_score, dist, hit))

# -----------------------------
# Summary
# -----------------------------
print("\n====== Summary ======")
print(f"Correct best-frame picks: {correct}/{len(test_files)} ({100.0*correct/len(test_files):.1f}%)")
print(f"(tolerance = ±{TOLERANCE}, MA={MA_WINDOW}, topk={TOPK}, thresh={THRESH})")

# Optional: show a few per-case lines
for f, bi, bs, d, h in details[:45]:
    print(f"{f}: best={bi:4d}  prob={bs:.3f}  dist_to_pos={d:3d}  {'HIT' if h else 'MISS'}")


Evaluating 45 cases (tolerance=±3)...



100%|██████████| 45/45 [00:43<00:00,  1.04it/s]


Correct best-frame picks: 38/45 (84.4%)
(tolerance = ±3, MA=9, topk=3, thresh=0.87)
d42fb920-5df1-4341-93df-480c17355e44.npz: best= 803  prob=0.948  dist_to_pos=  0  HIT
d5471cfd-6090-4d42-9a95-67ccbfbf612e.npz: best= 179  prob=0.936  dist_to_pos=  0  HIT
d571d4e1-ff80-44b9-a481-07961c6a1208.npz: best=  47  prob=0.963  dist_to_pos=  0  HIT
d5c3cfee-53ac-4021-8c1b-098c189f630e.npz: best= 623  prob=0.953  dist_to_pos=  0  HIT
d5f8c859-de93-4a50-b324-1ae4ad0267d4.npz: best=  68  prob=0.948  dist_to_pos=  0  HIT
d624338f-d09b-4bda-bbc3-3fa417015d6b.npz: best=  77  prob=0.982  dist_to_pos=  0  HIT
d77b6ece-da17-4f88-818c-0c7340b3e54f.npz: best=  58  prob=0.976  dist_to_pos=  0  HIT
d812091a-3635-4d51-9290-6adb3aa8681e.npz: best=  35  prob=0.925  dist_to_pos=  0  HIT
d8c3665a-4dc3-40ce-b716-f30aab365332.npz: best= 179  prob=0.728  dist_to_pos=412  MISS
db9d468d-cb20-4d5e-b059-31728f5950e6.npz: best= 333  prob=0.973  dist_to_pos=  0  HIT
dc0cbbdf-e4bb-4de5-958a-10576129e440.npz: best=  50  p




In [None]:
import os, numpy as np, torch
import torch.nn.functional as F
from torch import amp
from tqdm import tqdm

# ---------------------------------
# Config (edit these)
# ---------------------------------
DEVICE       = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_PATH   = "D:/acouslic-ai-cse4622/saved_weights/best_frame_classifier_afra.pth"
TEST_DIR     = "D:/dataset/converted_classifier_npz_compact"
START_IDX    = 255
NUM_CASES    = 45
IMAGE_SIZE   = 224
BATCH_SIZE   = 128
MA_WINDOW    = 11      # try 9–11; 11 is a bit smoother
# Hysteresis thresholds for plateau selection
HYST_HIGH    = 0.90    # set to your val-tuned threshold
HYST_LOW     = 0.80    # keep ~0.1 below HIGH
MIN_RUN_LEN  = 5       # min consecutive frames to count as a plateau
ALLOW_NONE   = False   # True -> return -1 if no frame passes thresholds
TOPK         = 1       # kept for compatibility; plateau selection returns 1 index
EVAL_TOLS    = (0, 3, 5)  # show accuracy at ±0/±3/±5

torch.backends.cudnn.benchmark = True

# ---------------------------------
# Utils (match training preprocessing)
# ---------------------------------
def letterbox_to_square_tensor(x: torch.Tensor, size=IMAGE_SIZE) -> torch.Tensor:
    """x: (B,1,H,W) or (1,H,W) in [0,1] -> square pad -> resize to (..,1,size,size)"""
    is_batched = (x.dim() == 4)
    if not is_batched: x = x.unsqueeze(0)
    _, _, H, W = x.shape
    s = max(H, W)
    pad_h = (s - H) // 2
    pad_w = (s - W) // 2
    x = F.pad(x, (pad_w, s - W - pad_w, pad_h, s - H - pad_h))  # L,R,T,B
    x = F.interpolate(x, size=(size, size), mode="bilinear", align_corners=False)
    return x if is_batched else x.squeeze(0)

def moving_average_safe(a: np.ndarray, w: int = 7) -> np.ndarray:
    """Edge-safe moving average using edge padding."""
    if w <= 1: return a
    pad = w // 2
    ap = np.pad(a, (pad, pad), mode="edge")
    k = np.ones(w, dtype=np.float32) / w
    return np.convolve(ap, k, mode="valid")

def _segments_above_threshold(a: np.ndarray, low: float):
    """Return list of (start,end) inclusive segments where a >= low."""
    on = a >= low
    segs = []
    i, n = 0, len(a)
    while i < n:
        if not on[i]:
            i += 1
            continue
        j = i
        while j + 1 < n and on[j + 1]:
            j += 1
        segs.append((i, j))
        i = j + 1
    return segs

def pick_best_index_from_plateaus(smoothed: np.ndarray,
                                  high: float,
                                  low: float,
                                  min_len: int) -> int:
    """
    Prefer sustained plateaus:
      1) segments where smoothed >= low
      2) keep those whose max >= high and length >= min_len
      3) choose the one with highest peak (break ties by longer length)
      4) return the segment center index
    Fallback: argmax if none qualify.
    """
    segs = _segments_above_threshold(smoothed, low)
    cand = []
    for (s, e) in segs:
        seg = smoothed[s:e+1]
        if seg.max() >= high and (e - s + 1) >= min_len:
            cand.append((s, e, seg.max()))
    if cand:
        cand.sort(key=lambda t: (t[2], (t[1] - t[0] + 1)), reverse=True)
        s, e, _ = cand[0]
        return (s + e) // 2
    return int(np.argmax(smoothed))

@torch.no_grad()
def predict_probs_over_frames(npz_path, model, batch_size=BATCH_SIZE):
    d = np.load(npz_path, mmap_mode="r")
    frames = d["image"]  # (T,H,W) uint8
    Tn = len(frames)
    probs = np.zeros(Tn, dtype=np.float32)
    off = 0
    while off < Tn:
        chunk = frames[off:off+batch_size]
        x = torch.from_numpy(chunk).unsqueeze(1).float() / 255.0      # (B,1,H,W) [0,1]
        x = letterbox_to_square_tensor(x, size=IMAGE_SIZE)            # (B,1,S,S)
        x = (x - 0.5) / 0.5                                           # [-1,1]
        x = x.to(DEVICE)

        with amp.autocast(device_type="cuda", enabled=(DEVICE == "cuda")):
            logits = model(x)
            p = torch.softmax(logits, dim=1)[:, 1].detach().cpu().numpy()
        probs[off:off+len(p)] = p
        off += len(p)
    return probs

@torch.no_grad()
def pick_best_frame(npz_path, model,
                    ma_window=MA_WINDOW,
                    high=HYST_HIGH, low=HYST_LOW,
                    min_run=MIN_RUN_LEN,
                    allow_none=ALLOW_NONE):
    """
    Returns:
      best_idx (int), best_score (float), raw_probs (np.ndarray), smoothed_probs (np.ndarray)
      If allow_none and no plateau meets criteria, best_idx == -1
    """
    probs = predict_probs_over_frames(npz_path, model)
    sm = moving_average_safe(probs, w=ma_window)

    best_idx = pick_best_index_from_plateaus(sm, high=high, low=low, min_len=min_run)

    # Optionally allow "no suitable frame" if even argmax is below high
    if allow_none and sm[best_idx] < high:
        return -1, float(sm.max()), probs, sm

    return int(best_idx), float(sm[best_idx]), probs, sm

def load_binary_labels(npz_path):
    y = np.load(npz_path, mmap_mode="r")["label"].astype(np.int64)
    y[y == 2] = 1
    return y

# ---------------------------------
# Load model
# ---------------------------------
model = FrameClassifier(num_classes=2).to(DEVICE)
ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
state = ckpt.get("model_state", ckpt)  # support plain state_dict too
model.load_state_dict(state)
model.eval()

# ---------------------------------
# Evaluate selected test files
# ---------------------------------
files = sorted([f for f in os.listdir(TEST_DIR) if f.endswith(".npz")])
test_files = files[START_IDX:START_IDX+NUM_CASES]
assert len(test_files) > 0, "No test files selected."

details = []  # (fname, best_idx, best_score, dist, hit@0, has_pos, bg_correct)

print(f"Evaluating {len(test_files)} cases (plateau: high={HYST_HIGH}, low={HYST_LOW}, min_run={MIN_RUN_LEN}, MA={MA_WINDOW})...\n")
for f in tqdm(test_files):
    path = os.path.join(TEST_DIR, f)
    y = load_binary_labels(path)
    pos_idx = np.where(y == 1)[0]
    has_pos = (len(pos_idx) > 0)

    best_idx, best_score, probs, sm = pick_best_frame(
        path, model, ma_window=MA_WINDOW, high=HYST_HIGH, low=HYST_LOW, min_run=MIN_RUN_LEN, allow_none=ALLOW_NONE
    )

    if best_idx == -1:
        # no suitable frame predicted
        dist = 0
        bg_correct = (not has_pos)  # correct only if truly no positive frames
        hit0 = bg_correct
    else:
        if has_pos:
            dist = int(np.min(np.abs(pos_idx - best_idx)))
            hit0 = (dist <= 0)
            bg_correct = False
        else:
            dist = 0
            hit0 = (y[best_idx] == 0)
            bg_correct = hit0

    details.append((f, best_idx, best_score, dist, hit0, has_pos, bg_correct))

# ---------------------------------
# Summary
# ---------------------------------
print("\n====== Summary ======")
for tol in EVAL_TOLS:
    correct_t = 0
    for _, _, _, dist, _, has_pos, bg_correct in details:
        if has_pos:
            correct_t += int(dist <= tol)
        else:
            correct_t += int(bg_correct)  # background-only sweeps handled separately
    print(f"Top-1 accuracy @ tol=±{tol}: {100.0*correct_t/len(details):.1f}%")

print(f"(MA={MA_WINDOW}, plateau high={HYST_HIGH}, low={HYST_LOW}, min_run={MIN_RUN_LEN}, allow_none={ALLOW_NONE})")

# Show all cases
for f, bi, bs, d, h0, _, _ in details:
    tag = "MISS" if (d > 0) else "HIT"
    print(f"{f}: best={bi:4d}  prob={bs:.3f}  dist_to_pos={d:3d}  {tag}")


Evaluating 45 cases (plateau: high=0.9, low=0.8, min_run=5, MA=11)...



100%|██████████| 45/45 [00:43<00:00,  1.05it/s]


Top-1 accuracy @ tol=±0: 80.0%
Top-1 accuracy @ tol=±3: 84.4%
Top-1 accuracy @ tol=±5: 84.4%
(MA=11, plateau high=0.9, low=0.8, min_run=5, allow_none=False)
d42fb920-5df1-4341-93df-480c17355e44.npz: best= 801  prob=0.920  dist_to_pos=  0  HIT
d5471cfd-6090-4d42-9a95-67ccbfbf612e.npz: best=  45  prob=0.950  dist_to_pos=  0  HIT
d571d4e1-ff80-44b9-a481-07961c6a1208.npz: best=  49  prob=0.934  dist_to_pos=  0  HIT
d5c3cfee-53ac-4021-8c1b-098c189f630e.npz: best= 168  prob=0.959  dist_to_pos=  0  HIT
d5f8c859-de93-4a50-b324-1ae4ad0267d4.npz: best=  70  prob=0.956  dist_to_pos=  0  HIT
d624338f-d09b-4bda-bbc3-3fa417015d6b.npz: best=  72  prob=0.967  dist_to_pos=  0  HIT
d77b6ece-da17-4f88-818c-0c7340b3e54f.npz: best=  57  prob=0.889  dist_to_pos=  0  HIT
d812091a-3635-4d51-9290-6adb3aa8681e.npz: best=  35  prob=0.927  dist_to_pos=  0  HIT
d8c3665a-4dc3-40ce-b716-f30aab365332.npz: best= 174  prob=0.622  dist_to_pos=417  MISS
db9d468d-cb20-4d5e-b059-31728f5950e6.npz: best= 335  prob=0.974  di




In [34]:
import os, numpy as np, torch
import torch.nn.functional as F
from torch import amp
from tqdm import tqdm

# -----------------------------
# Config (edit these)
# -----------------------------
DEVICE      = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_PATH  = "D:/acouslic-ai-cse4622/saved_weights/best_frame_classifier_anika.pth"
TEST_DIR    = "D:/dataset/converted_classifier_npz_compact"   # full 840-frame npz files
START_IDX   = 255   # where your test split starts
NUM_CASES   = 45
IMAGE_SIZE  = 224
MA_WINDOW   = 13    # temporal smoothing (moving average)
TOPK        = 1     # use top-1 for "best frame"
THRESH      = 0.90  # or e.g. 0.6 to require minimum smoothed prob
TOLERANCE   = 1     # 0 = exact frame; try 3 for ±3 frames

# -----------------------------
# Utils (match your training preprocessing)
# -----------------------------
def letterbox_to_square_tensor(x: torch.Tensor, size=IMAGE_SIZE) -> torch.Tensor:
    """x: (B,1,H,W) or (1,H,W) in [0,1] -> square pad -> resize to (..,1,size,size)"""
    is_batched = (x.dim() == 4)
    if not is_batched: x = x.unsqueeze(0)
    _, _, H, W = x.shape
    s = max(H, W)
    pad_h = (s - H) // 2
    pad_w = (s - W) // 2
    x = F.pad(x, (pad_w, s-W-pad_w, pad_h, s-H-pad_h))
    x = F.interpolate(x, size=(size, size), mode="bilinear", align_corners=False)
    return x if is_batched else x.squeeze(0)

def moving_average(a: np.ndarray, w: int = 7):
    if w <= 1: return a
    k = np.ones(w, dtype=np.float32) / w
    return np.convolve(a, k, mode="same")

@torch.no_grad()
def predict_probs_over_frames(npz_path, model, batch_size=128):
    d = np.load(npz_path, mmap_mode="r")
    frames = d["image"]  # (T,H,W) uint8
    Tn = len(frames)
    probs = np.zeros(Tn, dtype=np.float32)
    off = 0
    while off < Tn:
        chunk = frames[off:off+batch_size]
        x = torch.from_numpy(chunk).unsqueeze(1).float() / 255.0      # (B,1,H,W) [0,1]
        x = letterbox_to_square_tensor(x, size=IMAGE_SIZE)            # (B,1,S,S)
        x = (x - 0.5) / 0.5                                           # [-1,1]
        x = x.to(DEVICE)

        with amp.autocast(device_type="cuda", enabled=(DEVICE=="cuda")):
            logits = model(x)
            p = torch.softmax(logits, dim=1)[:, 1].detach().cpu().numpy()
        probs[off:off+len(p)] = p
        off += len(p)
    return probs

@torch.no_grad()
def pick_best_frame(npz_path, model, ma_window=7, topk=1, thresh=None):
    """Returns best_idx (int), best_score (float), idx_topk (np.ndarray), raw_probs, smoothed_probs."""
    probs = predict_probs_over_frames(npz_path, model)
    sm = moving_average(probs, w=ma_window)
    idx_sorted = np.argsort(-sm)
    idx_topk = idx_sorted[:topk]
    if thresh is not None:
        idx_topk = np.array([i for i in idx_topk if sm[i] >= thresh], dtype=int)
        if len(idx_topk) == 0:
            idx_topk = np.array([int(np.argmax(sm))], dtype=int)
    best_idx = int(idx_topk[0])
    best_score = float(sm[best_idx])
    return best_idx, best_score, idx_topk, probs, sm

def load_binary_labels(npz_path):
    y = np.load(npz_path, mmap_mode="r")["label"].astype(np.int64)
    y[y == 2] = 1
    return y

# -----------------------------
# Load model
# -----------------------------
model = FrameClassifier(num_classes=2).to(DEVICE)
ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
state = ckpt.get("model_state", ckpt)  # support plain state_dict too
model.load_state_dict(state)
model.eval()

# -----------------------------
# Pick 45 test files and evaluate
# -----------------------------
files = sorted([f for f in os.listdir(TEST_DIR) if f.endswith(".npz")])
test_files = files[START_IDX:START_IDX+NUM_CASES]
assert len(test_files) > 0, "No test files selected."

correct = 0
details = []

print(f"Evaluating {len(test_files)} cases (tolerance=±{TOLERANCE})...\n")
for f in tqdm(test_files):
    path = os.path.join(TEST_DIR, f)
    y = load_binary_labels(path)
    pos_idx = np.where(y == 1)[0]  # ground truth positive frames

    best_idx, best_score, idx_topk, probs, sm = pick_best_frame(
        path, model, ma_window=MA_WINDOW, topk=TOPK, thresh=THRESH
    )

    if len(pos_idx) == 0:
        hit = (y[best_idx] == 0)  # no positives: "correct" if best is background
        dist = 0
    else:
        # exact or ±tolerance match
        dist = int(np.min(np.abs(pos_idx - best_idx)))
        hit = (dist <= TOLERANCE)

    correct += int(hit)
    details.append((f, best_idx, best_score, dist, hit))

# -----------------------------
# Summary
# -----------------------------
print("\n====== Summary ======")
print(f"Correct best-frame picks: {correct}/{len(test_files)} ({100.0*correct/len(test_files):.1f}%)")
print(f"(tolerance = ±{TOLERANCE}, MA={MA_WINDOW}, topk={TOPK}, thresh={THRESH})")

# Optional: show a few per-case lines
for f, bi, bs, d, h in details[:45]:
    print(f"{f}: best={bi:4d}  prob={bs:.3f}  dist_to_pos={d:3d}  {'HIT' if h else 'MISS'}")


Evaluating 45 cases (tolerance=±1)...



100%|██████████| 45/45 [00:43<00:00,  1.04it/s]


Correct best-frame picks: 36/45 (80.0%)
(tolerance = ±1, MA=13, topk=1, thresh=0.9)
d42fb920-5df1-4341-93df-480c17355e44.npz: best= 802  prob=0.904  dist_to_pos=  0  HIT
d5471cfd-6090-4d42-9a95-67ccbfbf612e.npz: best=  46  prob=0.951  dist_to_pos=  0  HIT
d571d4e1-ff80-44b9-a481-07961c6a1208.npz: best= 342  prob=0.932  dist_to_pos=  0  HIT
d5c3cfee-53ac-4021-8c1b-098c189f630e.npz: best= 479  prob=0.944  dist_to_pos=  3  MISS
d5f8c859-de93-4a50-b324-1ae4ad0267d4.npz: best=  69  prob=0.952  dist_to_pos=  0  HIT
d624338f-d09b-4bda-bbc3-3fa417015d6b.npz: best=  69  prob=0.963  dist_to_pos=  0  HIT
d77b6ece-da17-4f88-818c-0c7340b3e54f.npz: best=  57  prob=0.799  dist_to_pos=  0  HIT
d812091a-3635-4d51-9290-6adb3aa8681e.npz: best=  37  prob=0.863  dist_to_pos=  0  HIT
d8c3665a-4dc3-40ce-b716-f30aab365332.npz: best= 595  prob=0.642  dist_to_pos=  0  HIT
db9d468d-cb20-4d5e-b059-31728f5950e6.npz: best= 335  prob=0.945  dist_to_pos=  0  HIT
dc0cbbdf-e4bb-4de5-958a-10576129e440.npz: best=  50  p


