# Acouslic-AI: SimCLR Pretraining → Frame Classifier (Ultrasound)

## 0. Setup

In [1]:

!python --version
import sys, torch, torchvision
print('Python:', sys.version)
print('PyTorch:', torch.__version__)
print('Torchvision:', torchvision.__version__)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device


Python 3.10.18
Python: 3.10.18 | packaged by Anaconda, Inc. | (main, Jun  5 2025, 13:08:55) [MSC v.1929 64 bit (AMD64)]
PyTorch: 2.8.0+cu128
Torchvision: 0.23.0+cu128


'cuda'

## 1. Paths & Config

In [2]:

# Adjust these to your environment
NPZ_DIR = "D:/dataset/npz_80"
SIMCLR_SAVE = "D:/acouslic-ai-cse4622/saved_weights/simclr_resnet50_ultrasound.pth"
CLS_SAVE = "D:/acouslic-ai-cse4622/saved_weights/best_frame_classifier_from_simclr.pth"

# Splits (no leakage)
TRAIN_FILES = slice(0, 210)   # first 210 npz for train
VAL_FILES   = slice(210, 255) # next 45 npz for val (adjust if needed)

SEED = 1337


## 2. Imports

In [3]:

import os, math, random, time
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as T
import torchvision.models as models

from tqdm import tqdm

# Reproducibility
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)


## 3. Dataset: `NPZFrameDataset`

In [4]:

class NPZFrameDataset(Dataset):
    """Loads frames from .npz files and returns (1×H×W) tensors resized to 224×224.
    Labels are kept as 0/1 (2→1 as in your original pipeline).
    """
    def __init__(self, npz_dir, files, transform=None):
        self.samples = []
        self.transform = transform

        for f in files:
            path = os.path.join(npz_dir, f)
            case = np.load(path, mmap_mode='r')  # lazy read
            images = case['image'].astype(np.float32)   # (F,H,W)
            labels = case['label'].astype(np.int64)     # (F,)

            # normalize [0,1] per sweep
            images = (images - images.min()) / (images.max() - images.min() + 1e-8)
            labels = labels.copy()
            labels[labels == 2] = 1  # collapse suboptimal→present for 2-class

            for img, lbl in zip(images, labels):
                self.samples.append((img, int(lbl)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img, lbl = self.samples[idx]

        # (H,W) → (1,H,W)
        img = torch.from_numpy(img).unsqueeze(0)

        # resize to (1,224,224)
        img = F.interpolate(img.unsqueeze(0),
                            size=(224, 224),
                            mode='bilinear',
                            align_corners=False).squeeze(0)

        if self.transform is not None:
            img = self.transform(img)   # tensor → tensor augmentations

        return img, torch.tensor(lbl).long()


## 4. Ultrasound-aware augmentations for SimCLR

In [5]:

class SpeckleNoise:
    def __init__(self, sigma=0.08, p=0.5):
        self.sigma, self.p = sigma, p
    def __call__(self, x):
        # x: tensor [1,H,W] in [0,1]
        if random.random() < self.p:
            noise = torch.randn_like(x) * self.sigma
            x = x * (1.0 + noise)
            x = x.clamp(0, 1)
        return x

class LineDropout:
    """Randomly zero out a few vertical scanlines (ultrasound-like)."""
    def __init__(self, p=0.25, max_lines=5, max_width=2):
        self.p, self.max_lines, self.max_width = p, max_lines, max_width
    def __call__(self, x):
        if random.random() >= self.p: return x
        C, H, W = x.shape
        k = random.randint(1, self.max_lines)
        for _ in range(k):
            col = random.randint(0, W-1)
            w = random.randint(1, self.max_width)
            x[:, :, max(0,col-w):min(W,col+w+1)] = 0.0
        return x

def get_simclr_transform():
    return T.Compose([
        T.RandomResizedCrop(224, scale=(0.6, 1.0), ratio=(0.9, 1.1)),
        T.RandomHorizontalFlip(p=0.5),
        T.RandomRotation(degrees=15),
        T.RandomApply([T.GaussianBlur(kernel_size=3)], p=0.2),
        SpeckleNoise(sigma=0.08, p=0.5),
        LineDropout(p=0.25, max_lines=5, max_width=2),
        T.RandomApply([T.Lambda(lambda x: (x + 0.02*torch.randn_like(x)).clamp(0,1))], p=0.2),
        T.Normalize(mean=[0.5], std=[0.5]),
    ])


## 5. SimCLR two-view wrapper

In [6]:

class SimCLRWrapper(Dataset):
    """Wraps a base dataset to emit two augmented views per sample."""
    def __init__(self, base: Dataset, aug_a, aug_b=None):
        self.base = base
        self.aug_a = aug_a
        self.aug_b = aug_b if aug_b is not None else aug_a
    def __len__(self): return len(self.base)
    def __getitem__(self, idx):
        x, _ = self.base[idx]      # ignore label in pretraining
        v1 = self.aug_a(x.clone())
        v2 = self.aug_b(x.clone())
        return v1, v2


## 6. SimCLR encoder + projector (ResNet50 grayscale)

In [7]:

def _resnet50_imagenet_grayscale():
    # Handle torchvision API differences for weights arg
    try:
        weights = models.ResNet50_Weights.IMAGENET1K_V1
        enc = models.resnet50(weights=weights)
    except Exception:
        enc = models.resnet50(pretrained=True)
    enc.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    return enc

class SimCLRModel(nn.Module):
    def __init__(self, proj_out=128):
        super().__init__()
        enc = _resnet50_imagenet_grayscale()
        # keep up to global avgpool (remove final fc)
        self.encoder = nn.Sequential(*list(enc.children())[:-1])  # [B,2048,1,1]
        feat_dim = 2048
        self.projector = nn.Sequential(
            nn.Linear(feat_dim, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, proj_out)
        )
    def forward(self, x):
        h = self.encoder(x).flatten(1)     # [B,2048]
        z = self.projector(h)              # [B,proj_out]
        z = F.normalize(z, dim=1)
        return h, z


## 7. NT-Xent (InfoNCE) loss

In [8]:
def nt_xent_loss(z1, z2, temperature=0.2):
    z1 = F.normalize(z1, dim=1).to(torch.float32)
    z2 = F.normalize(z2, dim=1).to(torch.float32)
    N = z1.size(0)
    z = torch.cat([z1, z2], dim=0)                 # [2N, D], fp32
    sim = (z @ z.t()) / float(temperature)         # [2N, 2N], fp32
    mask = torch.eye(2*N, device=z.device, dtype=torch.bool)
    sim = sim.masked_fill(mask, float('-inf'))     # safe in fp32

    pos_idx = torch.cat([torch.arange(N, 2*N, device=z.device),
                         torch.arange(0, N, device=z.device)], dim=0)
    labels = pos_idx
    return F.cross_entropy(sim, labels)


## 8. Pretrain SimCLR

In [9]:

def build_simclr_loaders(batch_size=128, num_workers=0):
    files = sorted(os.listdir(NPZ_DIR))
    train_files = files[TRAIN_FILES]
    base_ds = NPZFrameDataset(NPZ_DIR, train_files, transform=None)
    aug = get_simclr_transform()
    train_ds = SimCLRWrapper(base_ds, aug)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
    return train_loader

from tqdm import tqdm

def pretrain_simclr(epochs=50, batch_size=128, accum_steps=1, lr=3e-4,
                    temperature=0.2, num_workers=0):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    train_loader = build_simclr_loaders(batch_size=batch_size, num_workers=num_workers)

    model = SimCLRModel(proj_out=128).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    # New AMP API (removes deprecation warning)
    scaler = torch.amp.GradScaler('cuda', enabled=torch.cuda.is_available())

    model.train()
    os.makedirs(os.path.dirname(SIMCLR_SAVE), exist_ok=True)

    for epoch in range(1, epochs+1):
        pbar = tqdm(train_loader, desc=f"[SimCLR] Epoch {epoch}/{epochs}", leave=True)
        running = 0.0
        opt.zero_grad(set_to_none=True)

        for step, (v1, v2) in enumerate(pbar, start=1):
            v1, v2 = v1.to(device), v2.to(device)

            # New autocast API; forward in AMP, loss computed stably in nt_xent_loss (fp32)
            with torch.amp.autocast('cuda', enabled=torch.cuda.is_available()):
                _, z1 = model(v1)
                _, z2 = model(v2)

            loss = nt_xent_loss(z1, z2, temperature=temperature) / max(1, accum_steps)

            scaler.scale(loss).backward()
            if step % max(1,accum_steps) == 0:
                scaler.step(opt)
                scaler.update()
                opt.zero_grad(set_to_none=True)

            running += loss.item() * max(1,accum_steps)
            avg_loss = running / step
            curr_lr = opt.param_groups[0]['lr']
            pbar.set_postfix({'loss': f"{avg_loss:.4f}", 'lr': f"{curr_lr:.2e}"})

        sched.step()
        torch.save({'epoch': epoch, 'model': model.state_dict()}, SIMCLR_SAVE)
        print(f"Epoch {epoch}/{epochs} finished — avg loss {avg_loss:.4f}")

    print(f"✅ SimCLR pretraining complete. Saved to: {SIMCLR_SAVE}")



### ⏯ Run pretraining

In [10]:

# Example: reduce epochs/batch if memory-limited
pretrain_simclr(epochs=50, batch_size=64, accum_steps=2, lr=3e-4, temperature=0.2, num_workers=0)


[SimCLR] Epoch 1/50: 100%|██████████| 262/262 [02:20<00:00,  1.87it/s, loss=1.4281, lr=3.00e-04]


Epoch 1/50 finished — avg loss 1.4281


[SimCLR] Epoch 2/50: 100%|██████████| 262/262 [02:21<00:00,  1.85it/s, loss=1.0477, lr=3.00e-04]


Epoch 2/50 finished — avg loss 1.0477


[SimCLR] Epoch 3/50: 100%|██████████| 262/262 [02:18<00:00,  1.89it/s, loss=0.9838, lr=2.99e-04]


Epoch 3/50 finished — avg loss 0.9838


[SimCLR] Epoch 4/50: 100%|██████████| 262/262 [02:17<00:00,  1.91it/s, loss=0.9659, lr=2.97e-04]


Epoch 4/50 finished — avg loss 0.9659


[SimCLR] Epoch 5/50: 100%|██████████| 262/262 [02:23<00:00,  1.83it/s, loss=0.9407, lr=2.95e-04]


Epoch 5/50 finished — avg loss 0.9407


[SimCLR] Epoch 6/50: 100%|██████████| 262/262 [02:16<00:00,  1.92it/s, loss=0.9373, lr=2.93e-04]


Epoch 6/50 finished — avg loss 0.9373


[SimCLR] Epoch 7/50: 100%|██████████| 262/262 [02:20<00:00,  1.86it/s, loss=0.9525, lr=2.89e-04]


Epoch 7/50 finished — avg loss 0.9525


[SimCLR] Epoch 8/50: 100%|██████████| 262/262 [02:26<00:00,  1.79it/s, loss=0.9278, lr=2.86e-04]


Epoch 8/50 finished — avg loss 0.9278


[SimCLR] Epoch 9/50: 100%|██████████| 262/262 [02:26<00:00,  1.79it/s, loss=0.9116, lr=2.81e-04]


Epoch 9/50 finished — avg loss 0.9116


[SimCLR] Epoch 10/50: 100%|██████████| 262/262 [02:20<00:00,  1.87it/s, loss=0.9094, lr=2.77e-04]


Epoch 10/50 finished — avg loss 0.9094


[SimCLR] Epoch 11/50: 100%|██████████| 262/262 [02:21<00:00,  1.85it/s, loss=0.9007, lr=2.71e-04]


Epoch 11/50 finished — avg loss 0.9007


[SimCLR] Epoch 12/50: 100%|██████████| 262/262 [02:27<00:00,  1.78it/s, loss=0.8961, lr=2.66e-04]


Epoch 12/50 finished — avg loss 0.8961


[SimCLR] Epoch 13/50: 100%|██████████| 262/262 [02:24<00:00,  1.81it/s, loss=0.8913, lr=2.59e-04]


Epoch 13/50 finished — avg loss 0.8913


[SimCLR] Epoch 14/50: 100%|██████████| 262/262 [02:20<00:00,  1.87it/s, loss=0.8905, lr=2.53e-04]


Epoch 14/50 finished — avg loss 0.8905


[SimCLR] Epoch 15/50: 100%|██████████| 262/262 [02:22<00:00,  1.84it/s, loss=0.8832, lr=2.46e-04]


Epoch 15/50 finished — avg loss 0.8832


[SimCLR] Epoch 16/50: 100%|██████████| 262/262 [02:28<00:00,  1.76it/s, loss=0.8817, lr=2.38e-04]


Epoch 16/50 finished — avg loss 0.8817


[SimCLR] Epoch 17/50: 100%|██████████| 262/262 [02:25<00:00,  1.80it/s, loss=0.8829, lr=2.30e-04]


Epoch 17/50 finished — avg loss 0.8829


[SimCLR] Epoch 18/50: 100%|██████████| 262/262 [02:20<00:00,  1.87it/s, loss=0.8808, lr=2.22e-04]


Epoch 18/50 finished — avg loss 0.8808


[SimCLR] Epoch 19/50: 100%|██████████| 262/262 [02:29<00:00,  1.76it/s, loss=0.8788, lr=2.14e-04]


Epoch 19/50 finished — avg loss 0.8788


[SimCLR] Epoch 20/50: 100%|██████████| 262/262 [02:30<00:00,  1.74it/s, loss=0.8756, lr=2.05e-04]


Epoch 20/50 finished — avg loss 0.8756


[SimCLR] Epoch 21/50: 100%|██████████| 262/262 [02:27<00:00,  1.78it/s, loss=0.8814, lr=1.96e-04]


Epoch 21/50 finished — avg loss 0.8814


[SimCLR] Epoch 22/50: 100%|██████████| 262/262 [02:33<00:00,  1.70it/s, loss=0.8772, lr=1.87e-04]


Epoch 22/50 finished — avg loss 0.8772


[SimCLR] Epoch 23/50: 100%|██████████| 262/262 [02:27<00:00,  1.77it/s, loss=0.8758, lr=1.78e-04]


Epoch 23/50 finished — avg loss 0.8758


[SimCLR] Epoch 24/50: 100%|██████████| 262/262 [02:34<00:00,  1.70it/s, loss=0.8727, lr=1.69e-04]


Epoch 24/50 finished — avg loss 0.8727


[SimCLR] Epoch 25/50: 100%|██████████| 262/262 [02:24<00:00,  1.82it/s, loss=0.8668, lr=1.59e-04]


Epoch 25/50 finished — avg loss 0.8668


[SimCLR] Epoch 26/50: 100%|██████████| 262/262 [02:21<00:00,  1.85it/s, loss=0.8677, lr=1.50e-04]


Epoch 26/50 finished — avg loss 0.8677


[SimCLR] Epoch 27/50: 100%|██████████| 262/262 [02:30<00:00,  1.74it/s, loss=0.8673, lr=1.41e-04]


Epoch 27/50 finished — avg loss 0.8673


[SimCLR] Epoch 28/50: 100%|██████████| 262/262 [02:31<00:00,  1.73it/s, loss=0.8639, lr=1.31e-04]


Epoch 28/50 finished — avg loss 0.8639


[SimCLR] Epoch 29/50: 100%|██████████| 262/262 [02:30<00:00,  1.75it/s, loss=0.8636, lr=1.22e-04]


Epoch 29/50 finished — avg loss 0.8636


[SimCLR] Epoch 30/50: 100%|██████████| 262/262 [02:33<00:00,  1.71it/s, loss=0.8602, lr=1.13e-04]


Epoch 30/50 finished — avg loss 0.8602


[SimCLR] Epoch 31/50: 100%|██████████| 262/262 [02:31<00:00,  1.72it/s, loss=0.8582, lr=1.04e-04]


Epoch 31/50 finished — avg loss 0.8582


[SimCLR] Epoch 32/50: 100%|██████████| 262/262 [02:32<00:00,  1.72it/s, loss=0.8584, lr=9.48e-05]


Epoch 32/50 finished — avg loss 0.8584


[SimCLR] Epoch 33/50: 100%|██████████| 262/262 [02:32<00:00,  1.72it/s, loss=0.8575, lr=8.61e-05]


Epoch 33/50 finished — avg loss 0.8575


[SimCLR] Epoch 34/50: 100%|██████████| 262/262 [02:33<00:00,  1.71it/s, loss=0.8543, lr=7.77e-05]


Epoch 34/50 finished — avg loss 0.8543


[SimCLR] Epoch 35/50: 100%|██████████| 262/262 [02:32<00:00,  1.72it/s, loss=0.8553, lr=6.96e-05]


Epoch 35/50 finished — avg loss 0.8553


[SimCLR] Epoch 36/50: 100%|██████████| 262/262 [02:29<00:00,  1.75it/s, loss=0.8580, lr=6.18e-05]


Epoch 36/50 finished — avg loss 0.8580


[SimCLR] Epoch 37/50: 100%|██████████| 262/262 [02:30<00:00,  1.75it/s, loss=0.8528, lr=5.44e-05]


Epoch 37/50 finished — avg loss 0.8528


[SimCLR] Epoch 38/50: 100%|██████████| 262/262 [02:22<00:00,  1.84it/s, loss=0.8522, lr=4.73e-05]


Epoch 38/50 finished — avg loss 0.8522


[SimCLR] Epoch 39/50: 100%|██████████| 262/262 [02:26<00:00,  1.79it/s, loss=0.8517, lr=4.07e-05]


Epoch 39/50 finished — avg loss 0.8517


[SimCLR] Epoch 40/50: 100%|██████████| 262/262 [02:31<00:00,  1.73it/s, loss=0.8501, lr=3.44e-05]


Epoch 40/50 finished — avg loss 0.8501


[SimCLR] Epoch 41/50: 100%|██████████| 262/262 [02:29<00:00,  1.76it/s, loss=0.8510, lr=2.86e-05]


Epoch 41/50 finished — avg loss 0.8510


[SimCLR] Epoch 42/50: 100%|██████████| 262/262 [02:32<00:00,  1.72it/s, loss=0.8468, lr=2.34e-05]


Epoch 42/50 finished — avg loss 0.8468


[SimCLR] Epoch 43/50: 100%|██████████| 262/262 [02:31<00:00,  1.73it/s, loss=0.8491, lr=1.86e-05]


Epoch 43/50 finished — avg loss 0.8491


[SimCLR] Epoch 44/50: 100%|██████████| 262/262 [02:33<00:00,  1.71it/s, loss=0.8477, lr=1.43e-05]


Epoch 44/50 finished — avg loss 0.8477


[SimCLR] Epoch 45/50: 100%|██████████| 262/262 [02:30<00:00,  1.74it/s, loss=0.8449, lr=1.05e-05]


Epoch 45/50 finished — avg loss 0.8449


[SimCLR] Epoch 46/50: 100%|██████████| 262/262 [02:33<00:00,  1.71it/s, loss=0.8474, lr=7.34e-06]


Epoch 46/50 finished — avg loss 0.8474


[SimCLR] Epoch 47/50: 100%|██████████| 262/262 [02:33<00:00,  1.71it/s, loss=0.8448, lr=4.71e-06]


Epoch 47/50 finished — avg loss 0.8448


[SimCLR] Epoch 48/50: 100%|██████████| 262/262 [02:30<00:00,  1.74it/s, loss=0.8454, lr=2.66e-06]


Epoch 48/50 finished — avg loss 0.8454


[SimCLR] Epoch 49/50: 100%|██████████| 262/262 [02:32<00:00,  1.72it/s, loss=0.8455, lr=1.18e-06]


Epoch 49/50 finished — avg loss 0.8455


[SimCLR] Epoch 50/50: 100%|██████████| 262/262 [02:33<00:00,  1.71it/s, loss=0.8463, lr=2.96e-07]


Epoch 50/50 finished — avg loss 0.8463
✅ SimCLR pretraining complete. Saved to: D:/acouslic-ai-cse4622/saved_weights/simclr_resnet50_ultrasound.pth


## 9. Fine-tune the Frame Classifier (2-class CE)

In [11]:

# Supervised augmentations (lighter than SSL)
train_transform = T.Compose([
    T.RandomResizedCrop(224, scale=(0.8, 1.0), ratio=(0.95, 1.05)),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomRotation(degrees=10),
    T.RandomApply([T.GaussianBlur(3)], p=0.15),
    T.Normalize(mean=[0.5], std=[0.5]),
])
val_transform = T.Compose([
    T.Normalize(mean=[0.5], std=[0.5]),
])

class FrameClassifier(nn.Module):
    def __init__(self, simclr_ckpt_path=None, num_classes=2):
        super().__init__()
        # reuse SimCLR encoder definition
        enc = _resnet50_imagenet_grayscale()
        self.encoder = nn.Sequential(*list(enc.children())[:-1])  # [B,2048,1,1]
        if simclr_ckpt_path and os.path.isfile(simclr_ckpt_path):
            ckpt = torch.load(simclr_ckpt_path, map_location='cpu')
            # load encoder params (ignore projector)
            simclr_state = ckpt.get('model', ckpt)
            # Extract keys that belong to encoder.*
            encoder_state = {k.replace('encoder.', ''): v for k, v in simclr_state.items() if k.startswith('encoder.')}
            self.encoder.load_state_dict(encoder_state, strict=False)
            print(f"Loaded SimCLR encoder weights from {simclr_ckpt_path}")
        self.head = nn.Linear(2048, num_classes)

    def forward(self, x):
        h = self.encoder(x).flatten(1)  # [B,2048]
        return self.head(h)

def build_cls_loaders(batch_size=16, num_workers=0):
    files = sorted(os.listdir(NPZ_DIR))
    train_files = files[TRAIN_FILES]
    val_files = files[VAL_FILES]
    train_ds = NPZFrameDataset(NPZ_DIR, train_files, transform=train_transform)
    val_ds   = NPZFrameDataset(NPZ_DIR, val_files,   transform=val_transform)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  num_workers=num_workers)
    val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, val_loader

def finetune_classifier(epochs=20, batch_size=16, lr=1e-4, weight_decay=1e-4, num_workers=0, patience=5):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    train_loader, val_loader = build_cls_loaders(batch_size=batch_size, num_workers=num_workers)

    model = FrameClassifier(simclr_ckpt_path=SIMCLR_SAVE, num_classes=2).to(device)
    criterion = nn.CrossEntropyLoss()
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

    best_acc, no_improve = 0.0, 0
    os.makedirs(os.path.dirname(CLS_SAVE), exist_ok=True)

    for epoch in range(1, epochs+1):
        # ---- Train ----
        model.train()
        pbar = tqdm(train_loader, desc=f"[CLS] Epoch {epoch}/{epochs}")
        running = 0.0
        for imgs, labels in pbar:
            imgs, labels = imgs.to(device), labels.to(device)
            opt.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
                logits = model(imgs)
                loss = criterion(logits, labels)
            scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
            running += loss.item()
            pbar.set_postfix(loss=running/max(1,len(pbar)))

        # ---- Val ----
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for imgs, labels in val_loader:
                imgs, labels = imgs.to(device), labels.to(device)
                logits = model(imgs)
                preds = torch.argmax(logits, dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        val_acc = 100.0 * correct / max(1,total)
        print(f"Val Acc: {val_acc:.2f}%")

        sched.step()

        if val_acc > best_acc:
            best_acc, no_improve = val_acc, 0
            torch.save(model.state_dict(), CLS_SAVE)
            print(f"✅ Saved best model @ {best_acc:.2f}% → {CLS_SAVE}")
        else:
            no_improve += 1
            if no_improve >= patience:
                print("⏹️ Early stopping.")
                break

    print(f"Done. Best Val Acc = {best_acc:.2f}%  → {CLS_SAVE}")


### ⏯ Run fine-tuning

In [12]:

finetune_classifier(epochs=20, batch_size=16, lr=1e-4, weight_decay=1e-4, num_workers=0, patience=5)


  scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())


Loaded SimCLR encoder weights from D:/acouslic-ai-cse4622/saved_weights/simclr_resnet50_ultrasound.pth


  with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
[CLS] Epoch 1/20: 100%|██████████| 1050/1050 [01:36<00:00, 10.83it/s, loss=0.249]


Val Acc: 92.42%
✅ Saved best model @ 92.42% → D:/acouslic-ai-cse4622/saved_weights/best_frame_classifier_from_simclr.pth


[CLS] Epoch 2/20: 100%|██████████| 1050/1050 [01:37<00:00, 10.71it/s, loss=0.171]


Val Acc: 92.08%


[CLS] Epoch 3/20: 100%|██████████| 1050/1050 [01:43<00:00, 10.15it/s, loss=0.14] 


Val Acc: 93.61%
✅ Saved best model @ 93.61% → D:/acouslic-ai-cse4622/saved_weights/best_frame_classifier_from_simclr.pth


[CLS] Epoch 4/20: 100%|██████████| 1050/1050 [01:39<00:00, 10.52it/s, loss=0.125]


Val Acc: 92.36%


[CLS] Epoch 5/20: 100%|██████████| 1050/1050 [01:39<00:00, 10.58it/s, loss=0.107] 


Val Acc: 93.69%
✅ Saved best model @ 93.69% → D:/acouslic-ai-cse4622/saved_weights/best_frame_classifier_from_simclr.pth


[CLS] Epoch 6/20: 100%|██████████| 1050/1050 [01:39<00:00, 10.60it/s, loss=0.0934]


Val Acc: 94.19%
✅ Saved best model @ 94.19% → D:/acouslic-ai-cse4622/saved_weights/best_frame_classifier_from_simclr.pth


[CLS] Epoch 7/20: 100%|██████████| 1050/1050 [01:41<00:00, 10.35it/s, loss=0.0837]


Val Acc: 93.08%


[CLS] Epoch 8/20: 100%|██████████| 1050/1050 [01:43<00:00, 10.19it/s, loss=0.0674]


Val Acc: 93.44%


[CLS] Epoch 9/20: 100%|██████████| 1050/1050 [01:41<00:00, 10.38it/s, loss=0.0649]


Val Acc: 92.28%


[CLS] Epoch 10/20: 100%|██████████| 1050/1050 [01:41<00:00, 10.35it/s, loss=0.0511]


Val Acc: 92.33%


[CLS] Epoch 11/20: 100%|██████████| 1050/1050 [01:43<00:00, 10.10it/s, loss=0.0421]


Val Acc: 91.31%
⏹️ Early stopping.
Done. Best Val Acc = 94.19%  → D:/acouslic-ai-cse4622/saved_weights/best_frame_classifier_from_simclr.pth


## 10. Inference helper: pick best frame in a sweep

In [13]:

import torch.nn.functional as F

def pick_best_frame(npz_path, model_ckpt=CLS_SAVE, batch_size=64):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Load classifier
    clf = FrameClassifier(simclr_ckpt_path=None, num_classes=2).to(device)
    clf.load_state_dict(torch.load(model_ckpt, map_location=device))
    clf.eval()

    # Deterministic preprocessing
    def preprocess(batch):
        batch = torch.from_numpy(batch.astype(np.float32))
        # normalize [0,1]
        mn, mx = batch.min(), batch.max()
        batch = (batch - mn) / (mx - mn + 1e-8)
        batch = batch.unsqueeze(1)  # B×1×H×W
        batch = F.interpolate(batch, size=(224,224), mode='bilinear', align_corners=False)
        batch = (batch - 0.5) / 0.5
        return batch

    case = np.load(npz_path, mmap_mode='r')
    images = case['image']  # (F,H,W)

    best_idx, best_score = None, -1.0
    with torch.no_grad():
        for start in range(0, len(images), batch_size):
            chunk = images[start:start+batch_size]
            x = preprocess(chunk).to(device)
            logits = clf(x)
            probs = torch.softmax(logits, dim=1)[:,1]  # P(abdomen)
            s, j = torch.max(probs, dim=0)
            s, j = s.item(), j.item()
            if s > best_score:
                best_score = s
                best_idx = start + j
    return best_idx, best_score


In [14]:
npz_path = os.path.join(NPZ_DIR, sorted(os.listdir(NPZ_DIR))[210])
idx, score = pick_best_frame(npz_path)
print('Best frame:', idx, 'score:', score)


Best frame: 20 score: 0.9978273510932922


In [15]:
import os
import csv
import math
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F


def _load_model(ckpt_path: str, num_classes: int = 2) -> Tuple[torch.nn.Module, str]:
    """
    Load the frame classifier and return (model, device).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = FrameClassifier(num_classes=num_classes).to(device)
    state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state)
    model.eval()
    return model, device


def _get_label_array(case_npz: np.lib.npyio.NpzFile) -> np.ndarray:
    """
    Extract a 1D (F,) label array from common keys.
    Raises KeyError if none are found.
    """
    for key in ("label", "labels", "y", "gt", "target"):
        if key in case_npz:
            return np.asarray(case_npz[key]).reshape(-1)
    raise KeyError("No per-frame label key found (tried: label, labels, y, gt, target).")


def _predict_best_frame(npz_path: str, model: torch.nn.Module, device: str, batch_size: int = 64) -> Tuple[int, float]:
    """
    Run inference over all frames in an NPZ and return (best_frame_index, best_score).
    Assumes NPZ has key 'image' shaped (F, H, W).
    """
    case = np.load(npz_path, mmap_mode="r")
    images = case["image"]  # (F, H, W)

    best_idx, best_score = None, -1.0

    with torch.no_grad():
        for start in range(0, len(images), batch_size):
            batch = images[start:start + batch_size].astype(np.float32)     # (B,H,W)
            b = torch.from_numpy(batch).unsqueeze(1)                        # (B,1,H,W)

            # Per-frame min-max normalization to [0,1]
            b_min = b.amin(dim=(2, 3), keepdim=True)
            b_max = b.amax(dim=(2, 3), keepdim=True)
            b = (b - b_min) / (b_max - b_min + 1e-8)

            # Resize and normalize to [-1, 1]
            b = F.interpolate(b, size=(224, 224), mode="bilinear", align_corners=False)
            b = (b - 0.5) / 0.5
            b = b.to(device)

            logits = model(b)
            probs = torch.softmax(logits, dim=1)[:, 1]                      # P(positive class)

            max_prob, max_idx = torch.max(probs, dim=0)
            if max_prob.item() > best_score:
                best_score = max_prob.item()
                best_idx = start + max_idx.item()

    return int(best_idx), float(best_score)


def evaluate_fixed_range(
    root_dir: str,
    ckpt_path: str,
    start_idx: int = 255,
    end_idx: int = 300,
    batch_size: int = 64,
    positive_values: Tuple[int, ...] = (1, 2),
    save_csv: str = "range_255_300_results.csv",
    verbose: bool = True,
) -> Dict[str, float]:
    """
    Evaluate files in sorted(os.listdir(root_dir))[start_idx : end_idx+1].

    For each case:
      - Predict best frame with the classifier.
      - Extract ground-truth positives from per-frame labels inside the NPZ.
      - Compute distance to the nearest positive (0 = exact hit).
      - Record metrics and save per-case rows to CSV.

    Returns a dict with aggregate metrics and the CSV path.
    """
    # Load model once
    model, device = _load_model(ckpt_path)

    # Collect files
    files = sorted([f for f in os.listdir(root_dir) if f.lower().endswith(".npz")])
    if end_idx >= len(files):
        raise IndexError(f"end_idx={end_idx} out of range for {len(files)} files in {root_dir}")
    target_files = files[start_idx:end_idx + 1]

    if verbose:
        print(f"Evaluating files [{start_idx}:{end_idx}] → {len(target_files)} cases\n")

    rows: List[Dict[str, object]] = []
    total = 0
    hits_anypos = 0
    exact_firstpos_hits = 0
    prob_values: List[float] = []
    dists: List[int] = []

    for fname in target_files:
        path = os.path.join(root_dir, fname)
        case = np.load(path, mmap_mode="r")

        # Load labels
        try:
            labels = _get_label_array(case)  # (F,)
        except KeyError:
            if verbose:
                print(f"[WARN] {fname}: missing labels → skipped.")
            rows.append({
                "filename": fname, "n_frames": int(case["image"].shape[0]),
                "pred_idx": "", "pred_prob": "", "gt_positives": "",
                "nearest_gt": "", "abs_dist": "", "hit_anypos": "", "exact_firstpos": "",
                "note": "No per-frame labels found",
            })
            continue

        positives = np.where(np.isin(labels, positive_values))[0]  # indices of positive frames

        # Predict
        pred_idx, pred_prob = _predict_best_frame(path, model=model, device=device, batch_size=batch_size)

        total += 1
        prob_values.append(pred_prob)

        # Distance to nearest positive
        if positives.size > 0:
            nearest_gt = int(positives[np.argmin(np.abs(positives - pred_idx))])
            dist = int(abs(nearest_gt - pred_idx))
        else:
            nearest_gt = None
            dist = 0  # defined as 0 when no positives exist

        dists.append(dist)

        # Any positive hit?
        anypos = (pred_idx in positives) if positives.size > 0 else True  # no positives → trivially "not a miss"
        hits_anypos += int(anypos)

        # Exact match to first positive (argmax proxy)
        if positives.size > 0:
            first_pos = int(positives[0])
        else:
            first_pos = int(np.argmax(labels))  # fallback if all zeros
        exact_firstpos_hits += int(pred_idx == first_pos)

        # Row
        rows.append({
            "filename": fname,
            "n_frames": int(labels.shape[0]),
            "pred_idx": int(pred_idx),
            "pred_prob": round(float(pred_prob), 6),
            "gt_positives": ";".join(map(str, positives.tolist())) if positives.size > 0 else "",
            "nearest_gt": "" if nearest_gt is None else int(nearest_gt),
            "abs_dist": int(dist),
            "hit_anypos": int(anypos),
            "exact_firstpos": int(pred_idx == first_pos),
            "note": "",
        })

        if verbose:
            print(f"{fname}")
            print(f"  Frames: {len(labels)} | Pred idx: {pred_idx:4d} (p={pred_prob:.4f})")
            if positives.size > 0:
                plist = positives.tolist()
                print(f"  GT positives: {plist[:12]}{' ...' if len(plist) > 12 else ''}")
                print(f"  Nearest GT: {nearest_gt} | |dist|={dist} | Hit any +ve? {'YES' if anypos else 'NO'}\n")
            else:
                print("  GT positives: [] (no positive frames) | |dist|=0 (defined)\n")

    # Aggregates
    if total > 0:
        mae = float(np.mean(dists))
        acc_anypos = hits_anypos / total
        acc_exact_first = exact_firstpos_hits / total
        mean_prob = float(np.mean(prob_values))
    else:
        mae = math.nan
        acc_anypos = math.nan
        acc_exact_first = math.nan
        mean_prob = math.nan

    # Save CSV
    fieldnames = [
        "filename", "n_frames", "pred_idx", "pred_prob", "gt_positives",
        "nearest_gt", "abs_dist", "hit_anypos", "exact_firstpos", "note",
    ]
    with open(save_csv, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)

    if verbose:
        print("====== Summary ({}–{}) ======".format(start_idx, end_idx))
        print(f"Files evaluated:      {total}")
        print(f"Mean |dist to nearest +ve|: {mae:.2f} frames" if not math.isnan(mae) else "MAE: N/A")
        print(f"Hit any positive:     {hits_anypos}/{total}  ({acc_anypos*100:.1f}%)" if not math.isnan(acc_anypos) else "Hit any +ve: N/A")
        print(f"Exact first-positive: {exact_firstpos_hits}/{total}  ({acc_exact_first*100:.1f}%)" if not math.isnan(acc_exact_first) else "Exact first-positive: N/A")
        print(f"Mean predicted prob:  {mean_prob:.4f}" if not math.isnan(mean_prob) else "Mean prob: N/A")
        print(f"Saved per-case CSV to: {os.path.abspath(save_csv)}")

    return {
        "n_evaluated": float(total),
        "mae_to_nearest_positive": mae,
        "acc_any_positive": acc_anypos,
        "acc_exact_first_positive": acc_exact_first,
        "mean_pred_prob": mean_prob,
        "csv_path": os.path.abspath(save_csv),
    }

# ---------------- Example ----------------
summary = evaluate_fixed_range(
    root_dir="D:/dataset/mult_mha_to_npz",
    ckpt_path="D:/acouslic-ai-cse4622/saved_weights/best_frame_classifier_from_simclr.pth",
    start_idx=255,
    end_idx=299,
    batch_size=64,
    positive_values=(1, 2),                     # adjust if your labels differ
    save_csv="range_255_300_results.csv",
    verbose=True,
)
print(summary)


Evaluating files [255:299] → 45 cases

d42fb920-5df1-4341-93df-480c17355e44.npz
  Frames: 840 | Pred idx:   68 (p=0.9934)
  GT positives: [799, 800, 801, 802, 803, 804, 805, 806]
  Nearest GT: 799 | |dist|=731 | Hit any +ve? NO

d5471cfd-6090-4d42-9a95-67ccbfbf612e.npz
  Frames: 840 | Pred idx:   44 (p=0.9971)
  GT positives: [42, 43, 44, 45, 46, 47, 48, 176, 177, 178, 179, 180] ...
  Nearest GT: 44 | |dist|=0 | Hit any +ve? YES

d571d4e1-ff80-44b9-a481-07961c6a1208.npz
  Frames: 840 | Pred idx:   45 (p=0.9955)
  GT positives: [43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54] ...
  Nearest GT: 45 | |dist|=0 | Hit any +ve? YES

d5c3cfee-53ac-4021-8c1b-098c189f630e.npz
  Frames: 840 | Pred idx:  626 (p=0.9988)
  GT positives: [20, 21, 22, 23, 24, 25, 164, 165, 166, 167, 168, 169] ...
  Nearest GT: 626 | |dist|=0 | Hit any +ve? YES

d5f8c859-de93-4a50-b324-1ae4ad0267d4.npz
  Frames: 840 | Pred idx:   69 (p=0.9997)
  GT positives: [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75] ...
  Neare

In [6]:
from typing import Dict, Tuple
import numpy as np
import torch
import torch.nn.functional as F

def confusion_binary_from_counts(tp, fp, fn, tn) -> np.ndarray:
    return np.array([[tn, fp],
                     [fn, tp]], dtype=int)
    

def evaluate_frame_level_confusion(
    root_dir: str,
    ckpt_path: str,
    threshold: float = 0.5,
    batch_size: int = 128,
    positive_values: Tuple[int, ...] = (1, 2),
    num_classes: int = 2,
    verbose: bool = True,
) -> Dict[str, object]:
    #model, device = _load_model(ckpt_path, num_classes=num_classes)

    files = sorted([f for f in os.listdir(root_dir) if f.lower().endswith(".npz")])
    tp = fp = fn = tn = 0
    n_frames_total = 0

    with torch.no_grad():
        for fname in files:
            path = os.path.join(root_dir, fname)
            case = np.load(path, mmap_mode="r")

            # Must have per-frame labels
            try:
                labels = _get_label_array(case)  # (F,)
            except KeyError:
                if verbose:
                    print(f"[WARN] {fname}: missing labels → skipped for confusion.")
                continue

            images = case["image"].astype(np.float32)  # (F,H,W)
            F_count = images.shape[0]
            n_frames_total += F_count

            # Ground-truth (binary)
            y_true = np.isin(labels, positive_values).astype(np.int64)  # (F,)

            # Batched inference to get per-frame probs
            probs_all = []
            for start in range(0, F_count, batch_size):
                batch = images[start:start + batch_size]                 # (B,H,W)
                b = torch.from_numpy(batch).unsqueeze(1)                 # (B,1,H,W)

                # Per-frame min-max to [0,1]
                b_min = b.amin(dim=(2, 3), keepdim=True)
                b_max = b.amax(dim=(2, 3), keepdim=True)
                b = (b - b_min) / (b_max - b_min + 1e-8)

                # Resize → [-1,1]
                b = F.interpolate(b, size=(224, 224), mode="bilinear", align_corners=False)
                b = (b - 0.5) / 0.5
                b = b.to(device)

                logits = model(b)
                probs = torch.softmax(logits, dim=1)[:, 1].detach().cpu().numpy()  # (B,)
                probs_all.append(probs)

            probs_all = np.concatenate(probs_all, axis=0)                # (F,)
            y_pred = (probs_all >= threshold).astype(np.int64)

            # Update counts
            tp += int(((y_true == 1) & (y_pred == 1)).sum())
            fp += int(((y_true == 0) & (y_pred == 1)).sum())
            fn += int(((y_true == 1) & (y_pred == 0)).sum())
            tn += int(((y_true == 0) & (y_pred == 0)).sum())

    cm = confusion_binary_from_counts(tp, fp, fn, tn)

    if verbose:
        print("Frame-level confusion matrix (rows: true [0,1], cols: pred [0,1])")
        print(cm)
        prec = tp / (tp + fp + 1e-8)
        rec  = tp / (tp + fn + 1e-8)
        acc  = (tp + tn) / max(1, (tp+tn+fp+fn))
        print(f"Frames evaluated: {n_frames_total}")
        print(f"Accuracy: {acc:.4f} | Precision: {prec:.4f} | Recall: {rec:.4f}")

    return {"confusion_matrix": cm, "tp": tp, "fp": fp, "fn": fn, "tn": tn}


In [5]:
evaluate_frame_level_confusion(
    root_dir="D:/dataset/mult_mha_to_npz",
    ckpt_path="D:/acouslic-ai-cse4622/saved_weights/best_frame_classifier_from_simclr.pth",
    threshold= 0.5,
    batch_size=128,
    positive_values=(1, 2),     
    num_classes =2,
    verbose=True
)

NameError: name 'os' is not defined