In [None]:
import os, time, json, random, glob, copy
import numpy as np, cv2, matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchvision import models
from albumentations import Compose, Resize, RandomBrightnessContrast, Rotate, HorizontalFlip, Normalize
from albumentations.pytorch import ToTensorV2
from torch.cuda.amp import autocast, GradScaler # For mixed precision training

# ---------------------------
# 0. Global configuration
# ---------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = 2

# Loss weights & thresholds
lambda_unc   = 0.5   # Weight for evidential uncertainty loss
target_risk  = 0.02  # Max allowed risk on kept samples (Currently not used directly, but good for context)
alpha_conf   = 0.05  # Conformal prediction error rate

learning_rate_p1  = 1e-4
learning_rate_p2  = 1e-3 # Higher learning rate for phase 2 because only a small head is trained
batch_size_p1     = 8   # Batch size for phase 1
batch_size_p2     = 64  # Larger batch size for phase 2 (it processes feature vectors, not raw images)
accumulation_steps = 2 # Gradient accumulation
num_epochs_p1     = 10
num_epochs_p2     = 10 # Reduced epochs for lighter phase 2

# Paths
positive_folder  = "/kaggle/input/sarscov2-ctscan-dataset/COVID"
negative_folder  = "/kaggle/input/sarscov2-ctscan-dataset/non-COVID"
checkpoint_dir   = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

# Set environment variable to reduce memory fragmentation
# NOTE(review): this is set after `import torch`; confirm the CUDA allocator still picks it up.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# ---------------------------
# 1. Load & split data
# ---------------------------
def load_data_paths():
    """Collect the PNG paths of both classes and split each 80/20.

    Both lists are shuffled in place with the global ``random`` state
    (positives first, then negatives) before being cut.

    Returns:
        A 4-tuple ``(train_pos, train_neg, held_pos, held_neg)`` of path lists.
    """
    positives = glob.glob(os.path.join(positive_folder, "*.png"))
    negatives = glob.glob(os.path.join(negative_folder, "*.png"))
    for path_list in (positives, negatives):
        random.shuffle(path_list)

    cut_pos = int(0.8 * len(positives))
    cut_neg = int(0.8 * len(negatives))
    head_pos, tail_pos = positives[:cut_pos], positives[cut_pos:]
    head_neg, tail_neg = negatives[:cut_neg], negatives[cut_neg:]
    return head_pos, head_neg, tail_pos, tail_neg

# Build the train / held-out path lists once, at module execution time.
train_pos, train_neg, test_pos, test_neg = load_data_paths()

# ---------------------------
# 2. Dataset
# ---------------------------
class CTScanDataset(Dataset):
    """Image dataset built from two lists of file paths.

    Paths in ``pos_paths`` get label 0 and paths in ``neg_paths`` get label 1.
    Each item is ``(transformed_image, label, path)``; the path is returned so
    callers can log or debug individual samples.
    """

    def __init__(self, pos_paths, neg_paths, transform):
        """Store the paths, their labels and the transform.

        Args:
            pos_paths: sequence of image paths labelled 0.
            neg_paths: sequence of image paths labelled 1.
            transform: callable invoked as ``transform(image=rgb)`` whose
                result is indexed with ``["image"]``.
        """
        self.paths = list(pos_paths) + list(neg_paths)
        self.labels = [0] * len(pos_paths) + [1] * len(neg_paths)
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i):
        """Load, colour-convert and transform the ``i``-th image.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        path = self.paths[i]
        label = self.labels[i]
        bgr = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising;
        # surface the offending path rather than a cryptic cvtColor error.
        if bgr is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        img = self.transform(image=rgb)["image"]
        return img, label, path

# ---------------------------
# 3. Model
# ---------------------------
class MobileNetV2WithUncertainty(nn.Module):
    """MobileNetV2 backbone with a classification head and an evidence head.

    ``forward`` returns ``(logits, evidence)`` where the logits are divided by
    the learnable ``temperature`` and the evidence is the Softplus output of
    ``unc_head``.
    """

    def __init__(self, num_classes=2, p_drop=0.5, temperature=1.0):
        super().__init__()
        # Pretrained backbone; its stock classifier is replaced by an identity.
        self.backbone = models.mobilenet_v2(pretrained=True)
        self.backbone.classifier = nn.Identity()
        feat_dim = self.backbone.last_channel

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p_drop)
        self.fc = nn.Linear(feat_dim, num_classes)
        self.unc_head = nn.Sequential(
            nn.Linear(feat_dim, 128),
            nn.ReLU(),
            nn.Dropout(p_drop),
            nn.Linear(128, num_classes),
            nn.Softplus(),
        )
        self.temperature = nn.Parameter(torch.tensor(temperature))
        # Filled in by forward_with_features / the backward hook.
        self.feature_grad = None
        self.feature_map = None

    def save_grad(self, grad):
        """Backward hook: remember the gradient w.r.t. the feature map."""
        self.feature_grad = grad

    def _apply_heads(self, pooled):
        """Apply dropout, then the classifier and the evidence head."""
        dropped = self.dropout(pooled)
        scaled_logits = self.fc(dropped) / self.temperature
        evidence = self.unc_head(dropped)
        return scaled_logits, evidence

    def forward(self, x):
        """Accept either an image batch (4-D) or pooled feature vectors (2-D).

        Raises:
            ValueError: if ``x`` is neither 2-D nor 4-D.
        """
        rank = x.dim()
        if rank == 4:
            # Image batch: run the backbone and global-average-pool it.
            pooled = self.pool(self.backbone.features(x)).flatten(1)
        elif rank == 2:
            # Already pooled features: skip the backbone entirely.
            pooled = x
        else:
            raise ValueError("Input tensor must be 2D (features) or 4D (image).")
        return self._apply_heads(pooled)

    def forward_with_features(self, x):
        """Forward pass that also exposes the feature map (used for Grad-CAM++).

        Registers ``save_grad`` on the feature map when it requires grad and
        stores a clone of the map in ``self.feature_map``.

        Returns:
            ``(feature_map, logits, evidence)``.
        """
        fmap = self.backbone.features(x)
        if fmap.requires_grad:
            fmap.register_hook(self.save_grad)
        self.feature_map = fmap.clone()
        logits, evidence = self._apply_heads(self.pool(fmap).flatten(1))
        return fmap, logits, evidence

# ---------------------------
# 4. Helper Functions
# ---------------------------
def compute_uncertainty(evidence):
    """Return the per-sample uncertainty ``K / S`` for a batch of evidence.

    ``alpha = evidence + 1`` and ``S`` is the row sum of ``alpha``.  ``K`` is
    taken from the width of ``evidence`` itself rather than from a module
    global, so the result is correct for any number of classes.

    Args:
        evidence: tensor of shape ``(batch, K)``.

    Returns:
        Tensor of shape ``(batch, 1)``.
    """
    alpha = evidence + 1
    strength = alpha.sum(dim=1, keepdim=True)
    return alpha.size(1) / strength

def evidential_loss(logits, evidence, y):
    """KL-style loss between the one-hot target and the expected probabilities.

    The expected class probabilities are ``alpha / sum(alpha)`` with
    ``alpha = evidence + 1``.  ``logits`` is accepted for signature
    compatibility but is not used.

    Args:
        logits: unused.
        evidence: tensor of shape ``(batch, K)``.
        y: integer class indices of shape ``(batch,)``.

    Returns:
        Scalar tensor: the batch mean of the per-sample divergence.
    """
    dirichlet = evidence + 1
    expected_prob = dirichlet / dirichlet.sum(dim=1, keepdim=True)
    one_hot = torch.zeros_like(dirichlet).scatter_(1, y.unsqueeze(1), 1)
    # The 1e-10 offsets keep both logarithms finite.
    log_ratio = torch.log(one_hot + 1e-10) - torch.log(expected_prob + 1e-10)
    per_sample = (one_hot * log_ratio).sum(dim=1)
    return per_sample.mean()

def optimize_temperature(model, val_loader, device):
    """Fit ``model.temperature`` by minimising NLL on ``val_loader`` with LBFGS.

    The network is run once, without gradients, and the temperature-free
    logits are cached.  The LBFGS closure then only rescales those cached
    logits, so no backbone graph is retained and the loader is not re-read on
    every closure evaluation.  The objective is the sum of the per-batch
    cross-entropy values.

    This relies on the model producing ``raw_logits / temperature``, which is
    what ``MobileNetV2WithUncertainty.forward`` in this file does.

    Args:
        model: module with a ``temperature`` parameter; called as ``model(x)``
            and expected to return ``(logits, _)``.
        val_loader: iterable of ``(x, y, _)`` batches.
        device: device the batches are moved to.
    """
    model.eval()
    nll_criterion = nn.CrossEntropyLoss()

    # One forward pass over the data; undo the current temperature so the
    # cached logits do not depend on the parameter being optimised.
    cached_batches = []
    with torch.no_grad():
        current_t = model.temperature.detach().float()
        for x, y, _ in val_loader:
            x, y = x.to(device), y.to(device)
            with autocast():
                logits, _ = model(x)
            cached_batches.append((logits.float() * current_t, y))

    if not cached_batches:
        # Nothing to calibrate on; leave the temperature untouched.
        print("Temperature optimization skipped: validation loader is empty.")
        return

    # Only the temperature parameter is optimised.
    optimizer = optim.LBFGS([model.temperature], lr=0.01, max_iter=50, history_size=100)

    def closure():
        optimizer.zero_grad()
        loss = 0
        for raw_logits, y in cached_batches:
            loss = loss + nll_criterion(raw_logits / model.temperature, y)
        loss.backward()
        return loss

    print(f"Optimizing temperature (initial: {model.temperature.item():.3f})...")
    optimizer.step(closure)
    print(f"Optimized temperature: {model.temperature.item():.3f}")

def compute_nonconformity(model, loader, device, is_features_loader=False):
    """Compute the nonconformity score ``1 - P(true class)`` for every sample.

    Args:
        model: called as ``model(x)`` and expected to return ``(logits, _)``.
        loader: iterable of ``(x, y, _)`` batches.
        device: device the batches are moved to.
        is_features_loader: kept for interface compatibility; the model is
            called the same way for image and feature inputs.

    Returns:
        1-D ``numpy`` array with one score per sample, in loader order.
    """
    model.eval()
    collected = []
    with torch.no_grad():
        for batch_x, batch_y, _ in loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            with autocast():
                logits, _ = model(batch_x)
            probs = F.softmax(logits, dim=1)
            # Probability assigned to the true class of each row.
            true_prob = probs.gather(1, batch_y.long().unsqueeze(1)).squeeze(1)
            # Subtract in Python floats, one value per sample.
            collected.extend(1 - p for p in true_prob.tolist())
            torch.cuda.empty_cache()
    return np.array(collected)

def set_conformal_threshold(scores, alpha=alpha_conf):
    """Return the split-conformal quantile of the calibration scores.

    Uses the finite-sample corrected rank ``ceil((n + 1) * (1 - alpha))``
    instead of the plain empirical ``1 - alpha`` quantile, which is what the
    split-conformal coverage guarantee requires.  When that rank exceeds
    ``n`` (too few calibration samples for the requested ``alpha``) the
    largest score is returned.

    Args:
        scores: array-like of nonconformity scores.
        alpha: target error rate.

    Returns:
        The threshold as a numpy float.

    Raises:
        ValueError: if ``scores`` is empty.
    """
    ordered = np.sort(np.asarray(scores, dtype=float).ravel())
    n = ordered.size
    if n == 0:
        raise ValueError("Cannot calibrate a conformal threshold from zero scores.")
    rank = int(np.ceil((n + 1) * (1 - alpha)))
    rank = min(max(rank, 1), n)
    return ordered[rank - 1]

def grad_cam_plus_plus(model, img_tensor, target_class):
    """Compute a Grad-CAM++ heatmap for one image.

    The backbone parameters are temporarily made trainable so the feature map
    receives a gradient; their original ``requires_grad`` flags are restored
    on exit, including when an error occurs.

    Args:
        model: module exposing ``backbone.features``, ``forward_with_features``
            (returning ``(feature_map, logits, _)``) and a ``feature_grad``
            attribute filled by its backward hook.
        img_tensor: single image tensor without batch dimension.
        target_class: index of the logit to explain.

    Returns:
        2-D ``numpy`` array with the spatial size of the feature map, scaled
        to ``[0, 1]``; all zeros if no gradient was captured.
    """
    model.eval()

    feature_params = list(model.backbone.features.parameters())
    original_flags = [p.requires_grad for p in feature_params]
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else img_tensor.device

    try:
        for p in feature_params:
            p.requires_grad = True

        # Drop any gradient left over from a previous call.
        model.feature_grad = None

        with torch.enable_grad():
            img_input = img_tensor.detach().unsqueeze(0).to(run_device)
            img_input.requires_grad_(True)

            with autocast():
                fmap, logits, _ = model.forward_with_features(img_input)

            model.zero_grad()
            score = logits[0, int(target_class)]
            score.backward()

        grad = model.feature_grad
        if grad is None:
            print("Warning: feature_grad is None. Grad-CAM++ might not work as expected.")
            return np.zeros((fmap.shape[2], fmap.shape[3]))

        # Work on detached float32 copies: the CAM is not differentiated, and
        # cubing a half-precision gradient underflows easily.
        grad = grad.detach().float()
        fmap = fmap.detach().float()

        # Grad-CAM++ pixel weights:
        #   alpha = g^2 / (2 * g^2 + sum_ab(A_ab) * g^3)
        grad2 = grad ** 2
        grad3 = grad2 * grad
        activation_sum = fmap.sum(dim=(2, 3), keepdim=True)
        alpha_denom = 2 * grad2 + activation_sum * grad3
        # Avoid division by zero.
        alpha_denom = torch.where(alpha_denom != 0, alpha_denom, torch.ones_like(alpha_denom))
        alpha = grad2 / alpha_denom

        weights = torch.sum(alpha * torch.relu(grad), dim=(2, 3), keepdim=True)
        cam = F.relu(torch.sum(weights * fmap, dim=1))

        # Normalize CAM to 0-1.
        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-10)
        return cam.cpu().numpy()[0]
    finally:
        # Put every backbone parameter back into the state it had on entry.
        for p, flag in zip(feature_params, original_flags):
            p.requires_grad = flag


def patch_uncertainty(fmap, patch_size=16):
    """Per-patch variance of a feature map, averaged over channels.

    The map is cut into non-overlapping ``patch_size x patch_size`` tiles;
    rows/columns that do not fill a whole tile are dropped.  If ``patch_size``
    is larger than the map, it is clamped to the smaller spatial dimension so
    small feature maps still produce a result.

    Args:
        fmap: tensor of shape ``(B, C, H, W)``.
        patch_size: tile edge length.

    Returns:
        Tensor of shape ``(B, H // p, W // p)`` where ``p`` is the effective
        patch size.
    """
    # Variance is computed in float32 regardless of the input dtype.
    fmap = fmap.float()
    B, C, H, W = fmap.shape
    patch_size = max(1, min(int(patch_size), H, W))
    n_h, n_w = H // patch_size, W // patch_size

    # (B, C, n_h, n_w, p, p): tile along height, then along width.
    tiles = fmap.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    tiles = tiles.reshape(B, C, n_h, n_w, patch_size * patch_size)

    # A single-element tile has no unbiased variance; fall back to the biased one.
    tile_var = tiles.var(dim=-1, unbiased=patch_size > 1)
    return tile_var.mean(dim=1)


def generate_heatmaps(model, img_tensor, target_class):
    """Produce a Grad-CAM++ map and a patch-variance uncertainty map.

    Args:
        model: model passed to ``grad_cam_plus_plus`` and queried through
            ``forward_with_features``.
        img_tensor: single image tensor without batch dimension.
        target_class: class index explained by the CAM.

    Returns:
        ``(cam, unc_map)`` as ``numpy`` arrays; ``unc_map`` is min-max scaled
        to ``[0, 1]``.
    """
    # Grad-CAM++ needs gradients from the backbone.
    cam = grad_cam_plus_plus(model, img_tensor, target_class)

    # The uncertainty map only needs the feature map, so no gradients here.
    model.eval()
    with torch.no_grad():
        with autocast():
            fmap, _, _ = model.forward_with_features(img_tensor.unsqueeze(0).to(device))

    unc_map = patch_uncertainty(fmap).cpu().numpy()[0]
    # Min-max scaling: divide by the range, not by the maximum alone.
    lowest, highest = unc_map.min(), unc_map.max()
    unc_map = (unc_map - lowest) / (highest - lowest + 1e-10)
    return cam, unc_map

# ---------------------------
# 5. Train/Validate Functions
# ---------------------------
def train_one_epoch(model, loader, cls_crit, opt, scaler, accumulation_steps, is_features_loader=False):
    """Run one training epoch with mixed precision and gradient accumulation.

    The optimizer steps every ``accumulation_steps`` batches.  If the epoch
    ends in the middle of an accumulation group, the remaining gradients are
    applied with one final step instead of being discarded.

    Args:
        model: called as ``model(x)`` and expected to return ``(logits, evidence)``.
        loader: iterable of ``(x, y, _)`` batches.
        cls_crit: classification loss applied to ``(logits, y)``.
        opt: optimizer.
        scaler: ``GradScaler`` used to scale the loss and step the optimizer.
        accumulation_steps: number of batches per optimizer step.
        is_features_loader: kept for interface compatibility; the model is
            called the same way for image and feature inputs.

    Returns:
        ``(mean_total_loss, mean_cls_loss, mean_unc_loss, accuracy)``.
    """
    model.train()
    tot_loss = tot_cls = tot_unc = 0
    correct = 0
    N = 0
    has_pending_grads = False
    opt.zero_grad()
    for i, (x, y, _) in enumerate(loader):  # third element (paths/indices) is unused
        x, y = x.to(device), y.to(device)
        with autocast():
            logits, evidence = model(x)
            l_cls = cls_crit(logits, y)
            l_unc = evidential_loss(logits, evidence, y)
            loss = (l_cls + lambda_unc * l_unc) / accumulation_steps
        scaler.scale(loss).backward()
        has_pending_grads = True
        if (i + 1) % accumulation_steps == 0:
            scaler.step(opt)
            scaler.update()
            opt.zero_grad()
            has_pending_grads = False
        bs = x.size(0)
        N += bs
        # Undo the 1/accumulation_steps scaling for reporting.
        tot_loss += loss.item() * bs * accumulation_steps
        tot_cls += l_cls.item() * bs
        tot_unc += l_unc.item() * bs
        correct += (logits.argmax(1) == y).sum().item()
        torch.cuda.empty_cache()

    # Apply gradients left over from an incomplete accumulation group.
    if has_pending_grads:
        scaler.step(opt)
        scaler.update()
        opt.zero_grad()

    return tot_loss / N, tot_cls / N, tot_unc / N, correct / N

def validate_one_epoch(model, loader, cls_crit, tau_conf):
    """Evaluate the model and compute reject-option metrics.

    A sample is rejected when its top softmax probability is below
    ``tau_conf``.  For up to five randomly chosen rejected samples, heatmaps
    are generated and written into ``checkpoint_dir``; failures there are
    printed and do not abort validation.

    Args:
        model: called as ``model(x)`` and expected to return ``(logits, evidence)``.
        loader: iterable of ``(x, y, paths)`` batches exposing ``.dataset``.
        cls_crit: classification loss applied to ``(logits, y)``.
        tau_conf: confidence threshold below which a sample is rejected.

    Returns:
        Dict with the mean loss, coverage, selective accuracy, risk, reject
        counts and the per-sample arrays / paths.
    """
    model.eval()
    loss_sum = 0
    n_seen = 0
    conf_chunks, unc_chunks, true_chunks, pred_chunks = [], [], [], []
    all_paths = []

    with torch.no_grad():
        for images, targets, batch_paths in loader:
            images = images.to(device)
            targets = targets.to(device)
            with autocast():
                logits, evidence = model(images)
            probs = F.softmax(logits, dim=1)
            conf = probs.max(1)[0]
            unc = compute_uncertainty(evidence).squeeze()
            batch_loss = cls_crit(logits, targets)

            batch_n = images.size(0)
            loss_sum += batch_loss.item() * batch_n
            n_seen += batch_n

            # atleast_1d guards against 0-d arrays when a batch has one sample.
            conf_chunks.append(np.atleast_1d(conf.cpu().numpy()))
            unc_chunks.append(np.atleast_1d(unc.cpu().numpy()))
            true_chunks.append(targets.cpu().numpy())
            pred_chunks.append(logits.argmax(1).cpu().numpy())
            all_paths.extend(batch_paths)
            torch.cuda.empty_cache()

    confs = np.concatenate(conf_chunks)
    uncs = np.concatenate(unc_chunks)
    trues = np.concatenate(true_chunks)
    preds = np.concatenate(pred_chunks)

    # Reject-option bookkeeping.
    rej = confs < tau_conf
    keep = ~rej
    cov = keep.mean()
    is_right = preds == trues
    is_wrong = preds != trues

    corr_kept = (is_right & keep).sum()
    # With nothing kept, selective accuracy is reported as 0 and risk as 1.
    if keep.sum() > 0:
        sel_acc = corr_kept / keep.sum()
    else:
        sel_acc = 0.0
    risk = 1 - sel_acc

    correct_rej = (is_wrong & rej).sum()  # wrong prediction, rejected
    false_rej = (is_right & rej).sum()    # right prediction, rejected
    missed_rej = (is_wrong & keep).sum()  # wrong prediction, kept

    # Visualise a handful of rejected samples.
    # NOTE(review): positions are used as dataset indices, which presumes the
    # loader iterates the dataset in order (no shuffling) - confirm at call sites.
    rejected_indices = np.where(rej)[0]
    if len(rejected_indices) > 0:
        chosen = random.sample(list(rejected_indices), min(len(rejected_indices), 5))
        print(f"Generating heatmaps for {len(chosen)} rejected samples...")
        for sample_idx in chosen:
            img_tensor = loader.dataset[sample_idx][0]
            predicted_class = preds[sample_idx]  # explain the model's own prediction

            try:
                cam, unc_map = generate_heatmaps(model, img_tensor, predicted_class)
                original_filename = os.path.basename(all_paths[sample_idx]).split('.')[0]
                cam_file = os.path.join(checkpoint_dir, f"heatmap_cam_{original_filename}.png")
                unc_file = os.path.join(checkpoint_dir, f"heatmap_unc_{original_filename}.png")
                plt.imsave(cam_file, cam, cmap='jet')
                plt.imsave(unc_file, unc_map, cmap='jet')
                print(f"Saved heatmaps for {original_filename}")
            except Exception as e:
                print(f"Error generating heatmaps for {all_paths[sample_idx]}: {e}")
            torch.cuda.empty_cache()

    return {
        'loss': loss_sum / n_seen,
        'coverage': cov,
        'selective_acc': sel_acc,
        'risk': risk,
        'correct_rejects': int(correct_rej),
        'false_rejects': int(false_rej),
        'missed_rejects': int(missed_rej),
        'confidences': confs, 'uncertainties': uncs,
        'true_labels': trues, 'predictions': preds,
        'paths': all_paths
    }

# ----------------------------------------
# NEW: Function to extract features for Phase 2
# ----------------------------------------
def extract_features(model, loader, device):
    """Run the backbone over ``loader`` and collect pooled feature vectors.

    Args:
        model: module exposing ``backbone.features`` and ``pool``.
        loader: iterable of ``(x, y, paths)`` batches.
        device: device the image batches are moved to.

    Returns:
        ``(features, labels, paths)``: two CPU tensors concatenated along
        dim 0 and the flat list of paths, all in loader order.
    """
    model.eval()
    feature_chunks, label_chunks, all_paths = [], [], []
    print("Extracting features for Phase 2 training...")
    with torch.no_grad():
        for batch_no, (images, labels, paths) in enumerate(loader, start=1):
            images = images.to(device)
            with autocast():
                # Backbone and global pooling only; the heads are not involved.
                pooled = model.pool(model.backbone.features(images)).flatten(1)
            feature_chunks.append(pooled.cpu())
            label_chunks.append(labels.cpu())
            all_paths.extend(paths)
            torch.cuda.empty_cache()
            if batch_no % 100 == 0:
                print(f"Processed {batch_no} batches for feature extraction.")
    print("Feature extraction complete.")
    features = torch.cat(feature_chunks, dim=0)
    targets = torch.cat(label_chunks, dim=0)
    return features, targets, all_paths

# ---------------------------
# 6. Run Training
# ---------------------------
def run_training():
    """Run the whole two-phase training and calibration pipeline.

    Steps:
      1. Phase 1 trains the backbone and classifier with the evidence head frozen.
      2. The temperature is optimised on the held-out loader.
      3. Pooled features of the training set are extracted once.
      4. Phase 2 trains only the evidence head on those features.
      5. A conformal quantile of the nonconformity scores is calibrated and
         converted to a confidence threshold for the final validation.

    Returns:
        ``(model, tau_conf)`` where ``tau_conf`` is the calibrated quantile of
        the nonconformity scores ``1 - P(true class)``.
    """
    # Augmented pipeline for training, deterministic one for evaluation.
    train_tf = Compose([Resize(224,224), RandomBrightnessContrast(0.2,0.2,p=0.5),
                         Rotate(40,p=0.5), HorizontalFlip(p=0.5),
                         Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
                         ToTensorV2()])
    test_tf  = Compose([Resize(224,224), Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]), ToTensorV2()])

    train_ds = CTScanDataset(train_pos, train_neg, train_tf)
    test_ds  = CTScanDataset(test_pos, test_neg, test_tf)  # evaluation always runs on images

    train_ld_p1 = DataLoader(train_ds, batch_size_p1, shuffle=True, num_workers=2, pin_memory=True)
    test_ld     = DataLoader(test_ds, batch_size_p1, shuffle=False, num_workers=2, pin_memory=True)

    # Model & criterions
    model = MobileNetV2WithUncertainty().to(device)
    cls_crit = nn.CrossEntropyLoss()
    scaler = GradScaler()

    # --- Phase 1: Train classifier ---
    print("\n--- Starting Phase 1: Training Classifier ---")
    # Everything is trainable except the evidence head.
    for p in model.parameters():
        p.requires_grad = True
    for p in model.unc_head.parameters():
        p.requires_grad = False

    # The optimizer only receives parameters that require gradients.
    opt1 = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate_p1)

    best_selective_acc_p1 = 0  # model-selection metric for phase 1
    best_w1 = copy.deepcopy(model.state_dict())

    for e in range(num_epochs_p1):
        t_loss, t_cls, t_unc, t_acc = train_one_epoch(model, train_ld_p1, cls_crit, opt1, scaler, accumulation_steps, is_features_loader=False)
        val = validate_one_epoch(model, test_ld, cls_crit, 0.5)  # provisional threshold of 0.5
        print(f"[P1][Epoch {e+1}/{num_epochs_p1}] Train Loss: {t_loss:.4f} (CLS: {t_cls:.4f}, UNC: {t_unc:.4f}) Acc: {t_acc:.3f} | Val Cov: {val['coverage']:.3f} Sel Acc: {val['selective_acc']:.3f} Risk: {val['risk']:.3f}")

        # Keep the weights with the best selective accuracy.
        if val['selective_acc'] > best_selective_acc_p1:
            best_selective_acc_p1 = val['selective_acc']
            best_w1 = copy.deepcopy(model.state_dict())
            print(f"    New best P1 model saved with Sel Acc: {best_selective_acc_p1:.3f}")

    model.load_state_dict(best_w1)
    print("Phase 1 training complete. Loaded best model.")
    torch.cuda.empty_cache()  # release memory after phase 1

    # --- Optimize Temperature ---
    # NOTE(review): calibration and model selection both use the same held-out
    # loader that is later reported on - confirm this is intended.
    optimize_temperature(model, test_ld, device)
    torch.cuda.empty_cache()

    # --- Feature Extraction for Phase 2 ---
    train_features, train_labels, train_paths_features = extract_features(model, train_ld_p1, device)
    # Phase 2 trains on cached features; a running index stands in for the path
    # so batches keep the (x, y, extra) shape train_one_epoch expects.
    train_ds_p2 = TensorDataset(train_features, train_labels, torch.arange(len(train_labels)))
    train_ld_p2 = DataLoader(train_ds_p2, batch_size_p2, shuffle=True, num_workers=2, pin_memory=True)

    del train_features, train_labels  # drop the local references
    torch.cuda.empty_cache()

    # --- Phase 2: Train only Uncertainty Head ---
    print("\n--- Starting Phase 2: Training Uncertainty Head ---")
    # Freeze everything, then unfreeze the evidence head only.
    for p in model.parameters():
        p.requires_grad = False
    for p in model.unc_head.parameters():
        p.requires_grad = True

    opt2 = optim.Adam(model.unc_head.parameters(), lr=learning_rate_p2)

    best_risk_p2 = float('inf')  # phase 2 selects on the lowest risk
    best_w2 = copy.deepcopy(model.state_dict())

    for e in range(num_epochs_p2):
        t_loss, t_cls, t_unc, t_acc = train_one_epoch(model, train_ld_p2, cls_crit, opt2, scaler, accumulation_steps, is_features_loader=True)
        val = validate_one_epoch(model, test_ld, cls_crit, 0.5)  # provisional threshold of 0.5
        print(f"[P2][Epoch {e+1}/{num_epochs_p2}] Train Loss: {t_loss:.4f} (CLS: {t_cls:.4f}, UNC: {t_unc:.4f}) Acc: {t_acc:.3f} | Val Cov: {val['coverage']:.3f} Sel Acc: {val['selective_acc']:.3f} Risk: {val['risk']:.3f}")

        # Smaller risk on the kept samples is better.
        if val['risk'] < best_risk_p2:
            best_risk_p2 = val['risk']
            best_w2 = copy.deepcopy(model.state_dict())
            print(f"    New best P2 model saved with Risk: {best_risk_p2:.3f}")

    model.load_state_dict(best_w2)
    print("Phase 2 training complete. Loaded best model.")
    torch.cuda.empty_cache()

    # --- Calibrate Conformal Threshold ---
    scores = compute_nonconformity(model, test_ld, device, is_features_loader=False)
    tau_conf = set_conformal_threshold(scores, alpha_conf)
    print(f"Calibrated conformal threshold: tau_conf={tau_conf:.3f}")

    # tau_conf lives on the nonconformity scale (1 - probability), while
    # validate_one_epoch rejects on confidence (probability).  A prediction is
    # nonconforming when 1 - confidence > tau_conf, i.e. confidence < 1 - tau_conf.
    conf_threshold = 1.0 - tau_conf
    print(f"Equivalent confidence threshold: {conf_threshold:.3f}")

    # --- Final Validation with Calibrated Threshold ---
    print("\n--- Performing Final Validation ---")
    final_val_results = validate_one_epoch(model, test_ld, cls_crit, conf_threshold)
    print(f"Final Validation Metrics:")
    print(f"  Coverage: {final_val_results['coverage']:.3f}")
    print(f"  Selective Accuracy: {final_val_results['selective_acc']:.3f}")
    print(f"  Risk on Kept Samples: {final_val_results['risk']:.3f}")
    print(f"  Correct Rejects: {final_val_results['correct_rejects']}")
    print(f"  False Rejects: {final_val_results['false_rejects']}")
    print(f"  Missed Rejects: {final_val_results['missed_rejects']}")

    # Save final model; 'conf_threshold' is an additional key, 'tau' is unchanged.
    final_path = os.path.join(checkpoint_dir, "final_model_with_tau.pth")
    torch.save({'state': model.state_dict(), 'tau': tau_conf, 'conf_threshold': conf_threshold}, final_path)
    print(f"Final model and conformal threshold saved to {os.path.join(checkpoint_dir, 'final_model_with_tau.pth')}")

    return model, tau_conf

# ---------------------------
# Main
# ---------------------------
if __name__ == "__main__":
    # Run the full pipeline and report the threshold it returns.
    final_model, tc = run_training()
    print("Done. Conformal Threshold:", tc)

  check_for_updates()



--- Starting Phase 1: Training Classifier ---


  scaler = GradScaler()
  with autocast():
  with autocast():


[P1][Epoch 1/10] Train Loss: 0.7232 (CLS: 0.4054, UNC: 0.6356) Acc: 0.818 | Val Cov: 1.000 Sel Acc: 0.938 Risk: 0.062
    New best P1 model saved with Sel Acc: 0.938
[P1][Epoch 2/10] Train Loss: 0.5403 (CLS: 0.2787, UNC: 0.5231) Acc: 0.889 | Val Cov: 1.000 Sel Acc: 0.952 Risk: 0.048
    New best P1 model saved with Sel Acc: 0.952
[P1][Epoch 3/10] Train Loss: 0.4631 (CLS: 0.2282, UNC: 0.4698) Acc: 0.908 | Val Cov: 1.000 Sel Acc: 0.946 Risk: 0.054
[P1][Epoch 4/10] Train Loss: 0.4229 (CLS: 0.2005, UNC: 0.4449) Acc: 0.922 | Val Cov: 1.000 Sel Acc: 0.960 Risk: 0.040
    New best P1 model saved with Sel Acc: 0.960
[P1][Epoch 5/10] Train Loss: 0.4012 (CLS: 0.1867, UNC: 0.4290) Acc: 0.932 | Val Cov: 1.000 Sel Acc: 0.970 Risk: 0.030
    New best P1 model saved with Sel Acc: 0.970
[P1][Epoch 6/10] Train Loss: 0.3745 (CLS: 0.1666, UNC: 0.4158) Acc: 0.943 | Val Cov: 1.000 Sel Acc: 0.980 Risk: 0.020
    New best P1 model saved with Sel Acc: 0.980
[P1][Epoch 7/10] Train Loss: 0.3679 (CLS: 0.1618, UN

  with autocast():


In [None]:
import os
import glob
import random

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt
from torchvision import models
from albumentations import Compose, Resize, Normalize
from albumentations.pytorch import ToTensorV2

# ---------------------------
# 0. General setup
# ---------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Path to the checkpoint of the trained model
# NOTE(review): the training cell earlier in this file saves
# "final_model_with_tau.pth" - confirm this path points at the intended file.
CHECKPOINT_PATH = "./checkpoints/final_model.pth"

# Folders containing the test images (COVID / non-COVID)
TEST_POS_FOLDER = "/kaggle/input/sarscov2-ctscan-dataset/COVID"
TEST_NEG_FOLDER = "/kaggle/input/sarscov2-ctscan-dataset/non-COVID"

# "Reject" thresholds (to be supplied from training or from a file)
# Modified: Thresholds will be passed as arguments or loaded
TAU_CONF = 0.6  # Placeholder, to be updated
TAU_UNC  = 0.05  # Placeholder, to be updated
BETA     = 0.5

# ---------------------------
# 1. Định nghĩa Model MobileNetV2WithUncertainty
# ---------------------------
class MobileNetV2WithUncertainty(nn.Module):
    """MobileNetV2 backbone with a classifier and a scalar uncertainty head.

    ``forward`` returns ``(logits, u_pred)`` where ``u_pred`` is the sigmoid
    output of ``uncertainty_head`` with one value per sample.
    """

    def __init__(self, num_classes=2, p_dropout=0.5):  # Modified: Increased dropout rate to 0.5
        super().__init__()
        backbone = models.mobilenet_v2(pretrained=True)
        # Attach a dropout module to every fifth feature block (4, 9, 14, ...).
        # NOTE(review): add_module only registers the layer; whether it is
        # executed depends on each block's own forward - verify.
        for idx in range(4, len(backbone.features), 5):
            backbone.features[idx].add_module("dropout", nn.Dropout(p=p_dropout))

        self.backbone = backbone
        self.backbone.classifier = nn.Identity()
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=p_dropout)

        feat_dim = self.backbone.last_channel
        self.fc = nn.Linear(feat_dim, num_classes)
        self.uncertainty_head = nn.Sequential(
            nn.Linear(feat_dim, 128),
            nn.ReLU(),
            nn.Dropout(p=p_dropout),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

        # Filled in by the backward hook registered in forward_with_features.
        self.feature_grad = None

    def save_feature_grad(self, grad):
        """Backward hook: remember the gradient w.r.t. the feature map."""
        self.feature_grad = grad

    def _apply_heads(self, features):
        """Pool, flatten and drop out the features, then run both heads."""
        flat = self.dropout(self.pool(features).flatten(1))
        logits = self.fc(flat)
        u_pred = self.uncertainty_head(flat)
        return logits, u_pred

    def forward(self, x):
        """Return ``(logits, u_pred)`` for an image batch."""
        return self._apply_heads(self.backbone.features(x))

    def forward_with_features(self, x):
        """Like ``forward`` but also returns a clone of the feature map.

        Registers ``save_feature_grad`` on the feature map when it requires
        grad, so a later backward pass fills ``self.feature_grad``.

        Returns:
            ``(feature_map_clone, logits, u_pred)``.
        """
        features = self.backbone.features(x)
        if features.requires_grad:
            features.register_hook(self.save_feature_grad)

        fmap = features.clone()
        logits, u_pred = self._apply_heads(features)
        return fmap, logits, u_pred


# ---------------------------
# 2. Hàm compute_gradcam & compute_uncertainty_cam
# ---------------------------
def compute_gradcam(model, img_tensor):
    """Compute a Grad-CAM heatmap for the model's top predicted class.

    The input is marked as requiring grad so the feature map always receives
    a gradient, even when the backbone is frozen, and any gradient left over
    from a previous call is cleared first.

    Args:
        model: module exposing ``forward_with_features`` (returning
            ``(feature_map, logits, _)``) and a ``feature_grad`` attribute
            filled by its backward hook.
        img_tensor: single image tensor of shape ``(C, H, W)``.

    Returns:
        ``(class_idx, top_conf, cam)`` where ``cam`` is a ``[0, 1]`` heatmap
        resized to the input's ``(H, W)``.

    Raises:
        RuntimeError: if no feature-map gradient was captured.
    """
    model.eval()
    model.feature_grad = None  # never reuse a gradient from an earlier call

    with torch.enable_grad():
        batch = img_tensor.detach().unsqueeze(0).to(device).requires_grad_(True)
        fmap, logits, _ = model.forward_with_features(batch)
        probs = F.softmax(logits, dim=1)[0]
        class_idx = torch.argmax(probs).item()
        top_conf = probs[class_idx].item()

        model.zero_grad()
        logits[0, class_idx].backward()

    if model.feature_grad is None:
        raise RuntimeError("No feature-map gradient was captured; cannot compute Grad-CAM.")

    grads = model.feature_grad[0].detach()
    fmap_data = fmap[0].detach()
    # Channel weights are the spatial mean of the gradients.
    weights = grads.mean(dim=(1, 2))
    cam = torch.relu((weights[:, None, None] * fmap_data).sum(dim=0))
    cam = cam.float().cpu().numpy()

    cam -= cam.min()
    cam /= (cam.max() + 1e-8)

    _, orig_h, orig_w = img_tensor.shape
    cam_resized = cv2.resize(cam, (orig_w, orig_h))
    return class_idx, top_conf, cam_resized


def compute_uncertainty_cam(model, img_tensor):
    """Compute a Grad-CAM-style heatmap for the predicted uncertainty.

    The input is marked as requiring grad so the feature map always receives
    a gradient, even when the backbone is frozen, and any gradient left over
    from a previous call is cleared first.

    Args:
        model: module exposing ``forward_with_features`` (returning
            ``(feature_map, _, u_pred)``) and a ``feature_grad`` attribute
            filled by its backward hook.
        img_tensor: single image tensor of shape ``(C, H, W)``.

    Returns:
        ``(u_global, cam)`` where ``u_global`` is the scalar uncertainty and
        ``cam`` is a ``[0, 1]`` heatmap resized to the input's ``(H, W)``.

    Raises:
        RuntimeError: if no feature-map gradient was captured.
    """
    model.eval()
    model.feature_grad = None  # never reuse a gradient from an earlier call

    with torch.enable_grad():
        batch = img_tensor.detach().unsqueeze(0).to(device).requires_grad_(True)
        fmap, _, u_pred = model.forward_with_features(batch)
        u_global = u_pred.item()

        model.zero_grad()
        # u_pred holds a single element; sum() makes it an explicit scalar.
        u_pred.sum().backward()

    if model.feature_grad is None:
        raise RuntimeError("No feature-map gradient was captured; cannot compute the uncertainty CAM.")

    grads = model.feature_grad[0].detach()
    fmap_data = fmap[0].detach()
    # Channel weights are the spatial mean of the gradients.
    weights = grads.mean(dim=(1, 2))
    cam = torch.relu((weights[:, None, None] * fmap_data).sum(dim=0))
    cam = cam.float().cpu().numpy()

    cam -= cam.min()
    cam /= (cam.max() + 1e-8)

    _, orig_h, orig_w = img_tensor.shape
    cam_resized = cv2.resize(cam, (orig_w, orig_h))
    return u_global, cam_resized


# ---------------------------
# 3. Hàm explain & reject
# ---------------------------
def explain_ct_reject(model, img_tensor, tau_conf=TAU_CONF, tau_unc=TAU_UNC):  # Modified: Accept tau_conf, tau_unc as parameters
    """Decide whether to reject a prediction and, if so, attach explanations.

    A sample is rejected when its top confidence is below ``tau_conf`` AND its
    predicted uncertainty is above ``tau_unc``.

    Returns:
        Dict with ``reject``, ``class_idx``, ``top_conf`` and ``u_global``;
        rejected samples additionally carry ``gradcam``, ``uncertainty_cam``
        and ``composite_cam``.
    """
    class_idx, top_conf, gradcam = compute_gradcam(model, img_tensor)
    u_global, unc_cam = compute_uncertainty_cam(model, img_tensor)

    outcome = {
        "reject": False,
        "class_idx": class_idx,
        "top_conf": top_conf,
        "u_global": u_global
    }
    if not ((top_conf < tau_conf) and (u_global > tau_unc)):
        return outcome

    # Blend both maps and rescale the blend to [0, 1].
    composite = BETA * gradcam + (1 - BETA) * unc_cam
    composite -= composite.min()
    composite /= (composite.max() + 1e-8)

    outcome["reject"] = True
    outcome["gradcam"] = gradcam
    outcome["uncertainty_cam"] = unc_cam
    outcome["composite_cam"] = composite
    return outcome


# ---------------------------
# 4. Hàm visualize overlay (trên ảnh 224x224)
# ---------------------------
def visualize_explanation(img_resized: np.ndarray, gradcam: np.ndarray,
                         unc_cam: np.ndarray, composite: np.ndarray):
    """
    Show the image next to its three heatmap overlays.

    img_resized: un-normalized numpy array (224x224x3), dtype uint8
    Heatmaps: each of shape (224,224) with values in [0,1]
    """
    base_rgb = img_resized.copy()
    if base_rgb.ndim == 2:
        # Grayscale input: replicate the channel three times.
        base_rgb = np.stack([base_rgb]*3, axis=-1)

    def overlay_heatmap(base, heatmap, cmap='jet', alpha=0.5):
        # The colormap returns RGBA floats in [0, 1]; keep RGB only.
        colored = plt.get_cmap(cmap)(heatmap)[..., :3]
        blended = base.astype(np.float32) * (1 - alpha) + colored * 255 * alpha
        return blended.astype(np.uint8)

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))

    axes[0].imshow(base_rgb)
    axes[0].set_title('Ảnh CT (224×224)')
    axes[0].axis('off')

    overlay_specs = (
        (gradcam, 'jet', 'Grad-CAM'),
        (unc_cam, 'hot', 'Uncertainty-CAM'),
        (composite, 'jet', 'Composite'),
    )
    for ax, (heatmap, cmap_name, title) in zip(axes[1:], overlay_specs):
        ax.imshow(overlay_heatmap(base_rgb, heatmap, cmap=cmap_name, alpha=0.5))
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()


# ---------------------------
# 5. Preprocess ảnh (Albumentations test)
# ---------------------------
# Evaluation pipeline: resize to 224x224, normalize with the mean/std below,
# then convert to a tensor. No random augmentation is applied.
test_transform = Compose([
    Resize(224, 224),
    Normalize(mean=[0.485, 0.456, 0.406],
              std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])

def load_and_preprocess(path):
    """
    Load one image from disk and prepare it for inference and for overlays.

    Returns:
      - img_np_orig: RGB numpy array (H_orig x W_orig x 3), dtype uint8
      - img_t: torch.Tensor (3 x 224 x 224), normalized for inference
      - img_np_resized: RGB numpy array (224 x 224 x 3), dtype uint8

    Raises:
      OSError: if OpenCV cannot read the file at `path` (missing, unreadable
               or not a supported image format).
    """
    # 1) Read the original image. cv2.imread does not raise on failure; it
    #    returns None, which would otherwise surface as an opaque cvtColor error.
    img_bgr = cv2.imread(path)
    if img_bgr is None:
        raise OSError(f"Could not read image file: {path}")
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    img_np_orig = img_rgb.copy()

    # 2) Resize to 224x224; this copy is what heatmaps get overlaid on.
    img_np_resized = cv2.resize(img_rgb, (224, 224))

    # 3) Apply the Albumentations test pipeline (normalize + to-tensor, no
    #    random augmentation).
    augmented = test_transform(image=img_np_resized)
    img_t = augmented["image"]  # Tensor (3,224,224)

    return img_np_orig, img_t, img_np_resized


# ---------------------------
# 6. Hàm inference toàn tập Test
# ---------------------------
def run_full_test_inference(model, tau_conf=TAU_CONF, tau_unc=TAU_UNC):  # thresholds can be overridden per call
    """
    Run reject-aware inference over every test image, print a summary and
    display the correctly rejected and the misclassified cases.

    Images are collected (*.png then *.jpg) from TEST_POS_FOLDER (true label 0)
    and TEST_NEG_FOLDER (true label 1). Each image goes through
    load_and_preprocess and explain_ct_reject(model, img_t, tau_conf, tau_unc).

    Returns:
      (rejected_count, misclassified_list, correct_rejections,
       non_rejected_accuracy, rejection_accuracy)
      where both accuracies are percentages (0 when their denominator is 0) and
      misclassified_list holds one dict per wrongly predicted image, rejected or not.
    """
    def _images_in(folder):
        # PNG matches first, JPG matches second.
        return glob.glob(os.path.join(folder, "*.png")) + \
               glob.glob(os.path.join(folder, "*.jpg"))

    def _show_cams(case):
        # Rejected cases carry the three heatmaps needed for the 4-panel figure.
        visualize_explanation(case["img_np_resized"],
                              case["gradcam"],
                              case["uncertainty_cam"],
                              case["composite_cam"])

    # Collect every test image.
    pos_paths = _images_in(TEST_POS_FOLDER)
    neg_paths = _images_in(TEST_NEG_FOLDER)
    all_test_paths = pos_paths + neg_paths

    # Ground truth is decided by the folder a path came from.
    true_labels = dict.fromkeys(pos_paths, 0)
    true_labels.update(dict.fromkeys(neg_paths, 1))

    total = len(all_test_paths)
    rejected_count = correct_rejections = 0
    non_rejected_correct = non_rejected_total = 0
    misclassified_list, correct_rejection_list = [], []

    print(f"Tổng số ảnh test: {total}")

    for idx, img_path in enumerate(all_test_paths):
        img_np_orig, img_t, img_np_resized = load_and_preprocess(img_path)
        result = explain_ct_reject(model, img_t, tau_conf, tau_unc)
        true_label = true_labels[img_path]

        if not result["reject"]:
            non_rejected_total += 1
            if result["class_idx"] == true_label:
                non_rejected_correct += 1
            else:
                # Kept but wrong: no heatmaps are available for this record.
                misclassified_list.append({
                    "path": img_path,
                    "true": true_label,
                    "pred": result["class_idx"],
                    "conf": result["top_conf"],
                    "unc": result["u_global"],
                    "reject": False,
                    "img_np_orig": img_np_orig,
                    "img_np_resized": img_np_resized
                })
        else:
            rejected_count += 1
            if result["class_idx"] != true_label:
                # Correct rejection: the prediction was wrong and was rejected.
                # The case is recorded in both lists (as separate dict objects).
                correct_rejections += 1
                record = {
                    "path": img_path,
                    "true": true_label,
                    "pred": result["class_idx"],
                    "conf": result["top_conf"],
                    "unc": result["u_global"],
                    "reject": True,
                    "gradcam": result["gradcam"],
                    "uncertainty_cam": result["uncertainty_cam"],
                    "composite_cam": result["composite_cam"],
                    "img_np_resized": img_np_resized,
                    "img_np_orig": img_np_orig
                }
                correct_rejection_list.append(record)
                misclassified_list.append(dict(record))

        # Progress line every 100 images and after the last one.
        if (idx + 1) % 100 == 0 or (idx + 1) == total:
            print(f"  Đã chạy inference {idx+1}/{total} ảnh → Rejected: {rejected_count}")

    # Accuracies as percentages; stay at 0 when there is nothing to divide by.
    non_rejected_accuracy = 0
    if non_rejected_total > 0:
        non_rejected_accuracy = non_rejected_correct / non_rejected_total * 100
    rejection_accuracy = 0
    if rejected_count > 0:
        rejection_accuracy = correct_rejections / rejected_count * 100

    # Summary.
    print(f"\nKẾT QUẢ:\n  Tổng số ảnh test            : {total}")
    print(f"  Tổng số ảnh bị rejected     : {rejected_count}")
    print(f"  Số ảnh reject đúng (misclassified và rejected): {correct_rejections}")
    print(f"  Tỉ lệ reject đúng           : {rejection_accuracy:.2f}%")
    print(f"  Số ảnh không rejected đúng  : {non_rejected_correct}/{non_rejected_total}")
    print(f"  Tỉ lệ chính xác không rejected: {non_rejected_accuracy:.2f}%")
    print(f"  Tổng số ảnh bị misclassified: {len(misclassified_list)}\n")

    # Show the correctly rejected cases.
    if correct_rejection_list:
        print("=== CÁC CASE REJECT ĐÚNG (MISCLASSIFIED VÀ REJECTED) ===")
        for case in correct_rejection_list:
            img_path, true_label = case["path"], case["true"]
            pred, conf, unc = case["pred"], case["conf"], case["unc"]

            print(f"--- Reject đúng: {os.path.basename(img_path)} ---")
            print(f"  True Label = {true_label} | Predicted = {pred} | Conf = {conf:.3f} | Unc = {unc:.4f}")
            _show_cams(case)

    # Show every misclassified case in detail.
    if misclassified_list:
        print("=== CÁC CASE MISCLASSIFIED ===")
        for case in misclassified_list:
            img_path, true_label = case["path"], case["true"]
            pred, conf, unc = case["pred"], case["conf"], case["unc"]

            print(f"--- Sai nhãn: {os.path.basename(img_path)} ---")
            print(f"  True Label = {true_label} | Predicted = {pred} | Conf = {conf:.3f} | Unc = {unc:.4f}")

            if not case["reject"]:
                # Kept-but-wrong cases only have the image itself to show.
                plt.figure(figsize=(4,4))
                plt.imshow(case["img_np_orig"])
                plt.title(f"Pred {pred} | Conf {conf:.3f}")
                plt.axis('off')
                plt.show()
            else:
                _show_cams(case)

    return rejected_count, misclassified_list, correct_rejections, non_rejected_accuracy, rejection_accuracy


# ---------------------------
# 7. Main: load the model and run inference on the full test set
# ---------------------------
if __name__ == "__main__":
    # Rejection thresholds handed to run_full_test_inference. The numbers below
    # are placeholders: replace them with the values selected during training.
    optimal_tau_conf = 0.75  # Example: Replace with value from training
    optimal_tau_unc = 0.03   # Example: Replace with value from training

    # Build the model, then load the checkpoint into it. The file is loaded
    # directly as a state_dict, so a missing file is reported up front.
    model = MobileNetV2WithUncertainty(num_classes=2).to(device)
    if not os.path.exists(CHECKPOINT_PATH):
        raise FileNotFoundError(f"Không tìm thấy checkpoint tại {CHECKPOINT_PATH}")
    model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
    model.eval()
    print("✔️ Đã load model từ:", CHECKPOINT_PATH, "\n")
    print(f"Using thresholds: tau_conf={optimal_tau_conf:.2f}, tau_unc={optimal_tau_unc:.4f}\n")

    # Run inference over the whole test set with the thresholds above.
    rejected_count, misclassified_list, correct_rejections, non_rejected_accuracy, rejection_accuracy = run_full_test_inference(
        model, tau_conf=optimal_tau_conf, tau_unc=optimal_tau_unc
    )

    # Final recap of the values returned by run_full_test_inference.
    print("\nHoàn tất inference toàn tập test.")
    print(f"Số lượng ảnh bị rejected     : {rejected_count}")
    print(f"Số lượng ảnh reject đúng     : {correct_rejections}")
    print(f"Tỉ lệ reject đúng            : {rejection_accuracy:.2f}%")
    print(f"Tỉ lệ chính xác không rejected: {non_rejected_accuracy:.2f}%")
    print(f"Số lượng ảnh bị misclassified: {len(misclassified_list)}")

In [None]:
import os
import glob
import json
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torchvision import models
from torch.utils.data import Dataset, DataLoader
from albumentations import Compose, Resize, Normalize
from albumentations.pytorch import ToTensorV2

# ---------------------------
# 0. General setup
# ---------------------------
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Checkpoint to evaluate and the two image folders used as the test set
# (COVID images are treated as the positive folder, non-COVID as the negative).
# NOTE(review): these are the same dataset folders the training cell reads its
# images from -- confirm the evaluation images are actually held out.
CHECKPOINT_PATH = "./checkpoints/final.pth"
TEST_POS_FOLDER = "/kaggle/input/sarscov2-ctscan-dataset/COVID"
TEST_NEG_FOLDER = "/kaggle/input/sarscov2-ctscan-dataset/non-COVID"

# ---------------------------
# 1. Model definition (same as training script)
# ---------------------------
class MobileNetV2WithUncertainty(nn.Module):
    """MobileNetV2 feature extractor with two heads on the pooled features.

    ``fc`` produces class logits and ``unc_head`` produces one sigmoid-squashed
    uncertainty value per sample. ``feature_grad`` holds the most recent
    gradient captured on the feature map by ``forward_with_features``.
    """

    def __init__(self, num_classes=2, p_dropout=0.5):
        super().__init__()
        net = models.mobilenet_v2(pretrained=True)
        # Attach a Dropout child to every fifth feature layer (indices 4, 9, 14, ...).
        # NOTE(review): add_module only registers the child module; whether it is
        # ever executed depends on each layer's own forward -- verify that the
        # dropout is really applied for layers that are not plain Sequentials.
        for layer in list(net.features)[4::5]:
            layer.add_module("dropout", nn.Dropout(p_dropout))
        # The stock classifier is replaced; only ``features`` is used in forward.
        net.classifier = nn.Identity()
        self.backbone = net

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p_dropout)

        width = self.backbone.last_channel
        self.fc = nn.Linear(width, num_classes)
        self.unc_head = nn.Sequential(
            nn.Linear(width, 128),
            nn.ReLU(),
            nn.Dropout(p_dropout),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )
        self.feature_grad = None

    def save_grad(self, grad):
        """Backward-hook callback: remember the gradient of the feature map."""
        self.feature_grad = grad

    def _apply_heads(self, feature_map):
        """Pool the feature map, apply dropout, and run both heads."""
        pooled = self.dropout(self.pool(feature_map).flatten(1))
        return self.fc(pooled), self.unc_head(pooled)

    def forward(self, x):
        """Return ``(logits, uncertainty)`` for a batch ``x``."""
        return self._apply_heads(self.backbone.features(x))

    def forward_with_features(self, x):
        """Return ``(feature_map_copy, logits, uncertainty)``.

        When the feature map takes part in autograd, a hook is registered on it
        so that a later backward pass stores its gradient in ``feature_grad``.
        """
        feats = self.backbone.features(x)
        if feats.requires_grad:
            feats.register_hook(self.save_grad)
        snapshot = feats.clone()
        logits, uncertainty = self._apply_heads(feats)
        return snapshot, logits, uncertainty

# ---------------------------
# 2. Inference-time transforms
# ---------------------------
# Test-time pipeline with no random augmentation: resize to 224x224, normalize
# each channel with the mean/std below, then convert to a tensor.
# NOTE(review): the mean/std values match the commonly used ImageNet channel
# statistics -- confirm they equal what the training transform used.
test_transform = Compose([
    Resize(224,224),
    Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
    ToTensorV2()
])

def load_and_preprocess(path):
    """Load one image and prepare it for inference and for overlays.

    Returns:
        ``(img_rgb, img_resized, img_t)``: the RGB image at its original size,
        the same image resized to 224x224, and the output of ``test_transform``
        applied to the resized image (the model input).

    Raises:
        OSError: if OpenCV cannot read the file at ``path`` (missing,
            unreadable or not a supported image format).
    """
    # cv2.imread does not raise on failure; it returns None, which would
    # otherwise surface as an opaque cvtColor error.
    img_bgr = cv2.imread(path)
    if img_bgr is None:
        raise OSError(f"Could not read image file: {path}")
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    img_resized = cv2.resize(img_rgb, (224, 224))
    img_t = test_transform(image=img_resized)["image"]
    return img_rgb, img_resized, img_t

# ---------------------------
# 3. explain_ct_reject (same logic)
# ---------------------------
def compute_gradcam(model, img_tensor):
    """Compute a Grad-CAM heatmap for the model's top-scoring class.

    Args:
        model: module exposing ``forward_with_features`` (returning
            ``(feature_map, logits, uncertainty)``) and a ``feature_grad``
            attribute that a backward hook on the feature map fills in.
        img_tensor: one un-batched input tensor; the batch axis is added here.

    Returns:
        ``(class_idx, cam)``: the argmax class index (int) and a 2-D numpy
        heatmap at feature-map resolution, min-max scaled to [0, 1].

    Raises:
        RuntimeError: if the backward pass did not deliver a gradient for the
            feature map (for example because the feature map does not require
            grad), instead of silently reusing a gradient from an earlier call.
    """
    # Feed the input on the device the model actually lives on; the module-level
    # ``device`` is only a fallback for models without parameters.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    model.zero_grad()
    # Drop any gradient left over from a previous call so it cannot be reused.
    model.feature_grad = None
    fmap, logits, _ = model.forward_with_features(
        img_tensor.unsqueeze(0).to(target_device))
    probs = F.softmax(logits, 1)[0]
    class_idx = probs.argmax().item()

    # A single backward pass is enough, so the graph does not need retaining.
    logits[0, class_idx].backward()
    if model.feature_grad is None:
        raise RuntimeError(
            "Grad-CAM: no gradient was captured for the feature map; "
            "make sure it requires grad when forward_with_features runs."
        )

    grads = model.feature_grad[0]
    # One weight per channel: the spatial mean of that channel's gradient.
    weights = grads.flatten(1).mean(dim=1)
    cam = (weights.view(-1, 1, 1) * fmap[0].detach()).sum(0).relu().cpu().numpy()
    # Min-max scale; the epsilon keeps a constant map from dividing by zero.
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
    return class_idx, cam

def compute_uncertainty_cam(model, img_tensor):
    """Compute a CAM-style heatmap for the model's uncertainty output.

    Same recipe as Grad-CAM, but the backward pass starts from the scalar
    uncertainty prediction instead of a class logit.

    Args:
        model: module exposing ``forward_with_features`` (returning
            ``(feature_map, logits, uncertainty)``) and a ``feature_grad``
            attribute that a backward hook on the feature map fills in.
        img_tensor: one un-batched input tensor; the batch axis is added here.

    Returns:
        ``(u, cam)``: the predicted uncertainty as a Python float and a 2-D
        numpy heatmap at feature-map resolution, min-max scaled to [0, 1].

    Raises:
        RuntimeError: if the backward pass did not deliver a gradient for the
            feature map, instead of silently reusing a gradient from an
            earlier call.
    """
    # Feed the input on the device the model actually lives on; the module-level
    # ``device`` is only a fallback for models without parameters.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    model.zero_grad()
    # Drop any gradient left over from a previous call so it cannot be reused.
    model.feature_grad = None
    fmap, _, u_pred = model.forward_with_features(
        img_tensor.unsqueeze(0).to(target_device))
    u = u_pred.item()

    # A single backward pass is enough, so the graph does not need retaining.
    u_pred.backward()
    if model.feature_grad is None:
        raise RuntimeError(
            "Uncertainty-CAM: no gradient was captured for the feature map; "
            "make sure it requires grad when forward_with_features runs."
        )

    grads = model.feature_grad[0]
    # One weight per channel: the spatial mean of that channel's gradient.
    weights = grads.flatten(1).mean(dim=1)
    cam = (weights.view(-1, 1, 1) * fmap[0].detach()).sum(0).relu().cpu().numpy()
    # Min-max scale; the epsilon keeps a constant map from dividing by zero.
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
    return u, cam

def explain_ct_reject(model, img_tensor, tau_conf, tau_unc, beta=0.5):
    """Classify one image, decide whether to reject it, and explain rejections.

    The image is rejected only when the confidence of the predicted class is
    below ``tau_conf`` AND the predicted uncertainty is above ``tau_unc``.

    Returns:
        dict with ``reject``, ``class_idx``, ``conf`` and ``unc``. Rejected
        samples additionally carry ``gradcam``, ``uncertainty_cam`` and
        ``composite_cam``, each resized to 224x224; the composite is the
        ``beta``-weighted blend of the two maps, min-max scaled again.
    """
    class_idx, gradcam = compute_gradcam(model, img_tensor)
    u, unc_cam = compute_uncertainty_cam(model, img_tensor)

    # Gradient-free forward pass to read the confidence of the predicted class.
    with torch.no_grad():
        batch = img_tensor.unsqueeze(0).to(device)
        logits, _ = model(batch)
        conf = F.softmax(logits, 1)[0, class_idx].item()

    low_confidence = conf < tau_conf
    high_uncertainty = u > tau_unc
    reject = low_confidence and high_uncertainty

    result = {
        "reject": reject,
        "class_idx": class_idx,
        "conf": conf,
        "unc": u,
    }
    if not reject:
        return result

    # Rejected: bring both maps to 224x224 and blend them.
    target_size = (224, 224)
    cam1 = cv2.resize(gradcam, target_size)
    cam2 = cv2.resize(unc_cam, target_size)
    comp = beta * cam1 + (1 - beta) * cam2
    result["gradcam"] = cam1
    result["uncertainty_cam"] = cam2
    result["composite_cam"] = (comp - comp.min()) / (comp.max() - comp.min() + 1e-8)
    return result

# ---------------------------
# 4. visualize
# ---------------------------
def overlay(base, heat, cmap, alpha=0.5):
    """Blend a heatmap, colored with the named matplotlib colormap, onto ``base``.

    The colormap output is scaled by 255 before mixing; ``alpha`` is the weight
    of the heatmap and ``1 - alpha`` the weight of the base image. The result
    is returned as uint8.
    """
    # The colormap returns RGBA; only the RGB channels are blended.
    heat_rgb = plt.get_cmap(cmap)(heat)[..., :3]
    blended = base * (1 - alpha) + heat_rgb * 255 * alpha
    return blended.astype(np.uint8)

def visualize_case(img, res):
    """Show ``img`` next to its three heatmap overlays in one 1x4 figure.

    ``res`` must contain the ``gradcam``, ``uncertainty_cam`` and
    ``composite_cam`` entries.
    """
    _, ax = plt.subplots(1, 4, figsize=(16, 4))

    # Panel 0: the plain image.
    ax[0].imshow(img)
    ax[0].set_title("Orig")
    ax[0].axis("off")

    # Panels 1-3: (key in res, colormap, title) for each overlay.
    layers = (
        ("gradcam", 'jet', "GradCAM"),
        ("uncertainty_cam", 'hot', "UncCAM"),
        ("composite_cam", 'jet', "Composite"),
    )
    for panel, (key, cmap, title) in zip(ax[1:], layers):
        panel.imshow(overlay(img, res[key], cmap))
        panel.set_title(title)
        panel.axis("off")

    plt.tight_layout()
    plt.show()

# ---------------------------
# 5. Full inference + metrics
# ---------------------------
def run_full_test(model, tau_conf, tau_unc):
    """Evaluate the reject option on every ``*.png`` test image and show cases.

    Images from ``TEST_POS_FOLDER`` get true label 0 and images from
    ``TEST_NEG_FOLDER`` get true label 1. A sample counts as rejected when its
    confidence is below ``tau_conf`` and its uncertainty is above ``tau_unc``.

    Prints the number of rejected samples, how many rejections were of
    misclassified samples ("correct rejects"), and the accuracy on the kept
    samples. Then displays every correct reject with its heatmaps and every
    misclassified-but-kept image. Returns nothing.

    Args:
        model: the network passed through to ``explain_ct_reject``.
        tau_conf: confidence threshold.
        tau_unc: uncertainty threshold.
    """
    pos = glob.glob(os.path.join(TEST_POS_FOLDER, "*.png"))
    neg = glob.glob(os.path.join(TEST_NEG_FOLDER, "*.png"))
    paths = pos + neg
    trues = np.array([0] * len(pos) + [1] * len(neg))

    confs, uncs, preds = [], [], []
    # index -> (224x224 image, explanation dict) for rejected AND misclassified
    # samples. Keeping these from the single inference pass avoids re-reading
    # the image and re-running the model when the cases are displayed, and
    # guarantees the displayed result is the one the statistics were built on.
    correct_rejects = {}
    for i, p in enumerate(paths):
        _, img224, img_t = load_and_preprocess(p)
        res = explain_ct_reject(model, img_t, tau_conf, tau_unc)
        confs.append(res["conf"])
        uncs.append(res["unc"])
        preds.append(res["class_idx"])
        if res["reject"] and res["class_idx"] != trues[i]:
            correct_rejects[i] = (img224, res)

    confs = np.array(confs)
    uncs = np.array(uncs)
    preds = np.array(preds)

    # Same rejection rule as explain_ct_reject, applied to the whole set.
    rej = (confs < tau_conf) & (uncs > tau_unc)
    keep = ~rej
    total = len(paths)
    RejCnt = rej.sum()
    CorrRej = np.logical_and(rej, preds != trues).sum()
    KeptCorr = np.logical_and(keep, preds == trues).sum()
    KeptTot = keep.sum()
    # Both ratios fall back to 0 when there is nothing to divide by.
    sel_acc = KeptCorr / KeptTot if KeptTot > 0 else 0
    rej_acc = CorrRej / RejCnt if RejCnt > 0 else 0
    print(f"Total: {total}, Rejected: {RejCnt}, CorrectRejects: {CorrRej}, RejectAcc: {rej_acc:.2%}")
    print(f"Kept total: {KeptTot}, SelectiveAcc: {sel_acc:.2%}")

    # Display the correct rejects (in path order) from the cached results.
    for i, (img224, res) in correct_rejects.items():
        print("Correct reject:", os.path.basename(paths[i]))
        visualize_case(img224, res)

    # Display the misclassified-but-kept images; the full-size image is read
    # again here rather than held in memory for every sample.
    for i, p in enumerate(paths):
        if keep[i] and preds[i] != trues[i]:
            img_rgb, _, _ = load_and_preprocess(p)
            print("Misclassified-kept:", os.path.basename(p), "pred", preds[i], "true", trues[i])
            plt.figure(figsize=(4, 4))
            plt.imshow(img_rgb)
            plt.axis("off")
            plt.show()

# ---------------------------
# 6. Main
# ---------------------------
if __name__ == "__main__":
    # Load the checkpoint. Two layouts are accepted:
    #   * a wrapper dict {'state': state_dict, 'tau': (tau_conf, tau_unc)}
    #   * a bare state_dict, which is what the earlier inference cell loads
    # so the same file works with either cell. When no thresholds are stored,
    # the defaults (0.85, 0.04) are used.
    ckpt = torch.load(CHECKPOINT_PATH, map_location=device)
    model = MobileNetV2WithUncertainty().to(device)
    if isinstance(ckpt, dict) and 'state' in ckpt:
        state_dict = ckpt['state']
        tau_conf, tau_unc = ckpt.get('tau', (0.85,0.04))
    else:
        state_dict = ckpt
        tau_conf, tau_unc = (0.85,0.04)
    model.load_state_dict(state_dict)
    print(f"Loaded model and thresholds: tau_conf={tau_conf:.2f}, tau_unc={tau_unc:.3f}")

    # Put the model in eval mode before running the evaluation.
    model.eval()
    run_full_test(model, tau_conf, tau_unc)
