In [1]:
import torch
import torch.nn as nn

def evaluate_mc_track_indices(
    model,
    branch,
    test_loader,
    device,
    tau=0.8,
    mc_samples=16,
    num_classes=10
):
    """
    Evaluate model with MC dropout, track for each true class the index of
    the misclassified sample with highest uncertainty.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()

    # best_mis[c] = (max_uncertainty, dataset_idx, pred_label)
    best_mis = {
        c: (float('-inf'), None, None)
        for c in range(num_classes)
    }

    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            # unpack batch
            if len(batch) == 3:
                images, labels, indices = batch
            else:
                images, labels = batch
                # fallback: compute dataset idx from batch_idx
                start = batch_idx * test_loader.batch_size
                indices = torch.arange(start, start + images.size(0))

            images = images.to(device)
            labels = labels.to(device)

            # forward → logits & uncertainty
            logits, uncertainty = model(
                images,
                branch=branch,
                inf_type="mc",
                out_type="logits",
                mc_samples=mc_samples,
                tau=tau
            )
            # logits: [B, C], uncertainty: [B]

            # compute loss & accuracy on mean logits
            loss = criterion(logits, labels)
            total_loss += loss.item() * images.size(0)

            preds = logits.argmax(dim=1)
            total_correct += (preds == labels).sum().item()
            total_samples += labels.size(0)

            # for each sample in the batch
            for i in range(images.size(0)):
                true_c = labels[i].item()
                pred_c = preds[i].item()
                if pred_c != true_c:
                    ent = uncertainty[i].item()
                    # update if higher uncertainty
                    if ent > best_mis[true_c][0]:
                        best_mis[true_c] = (
                            ent,
                            indices[i].item(),
                            pred_c
                        )

    # final stats
    avg_loss = total_loss / total_samples
    accuracy = 100.0 * total_correct / total_samples

    # print per-class highest-uncertainty misclassification
    print("\nPer-class highest-uncertainty misclassifications:")
    for cls in range(num_classes):
        max_ent, idx, pred = best_mis[cls]
        if idx is not None:
            print(
                f" Class {cls:2d} → idx {idx:5d}, "
                f"pred={pred:2d}, uncertainty={max_ent:.4f}"
            )
        else:
            print(f" Class {cls:2d} → no misclassifications.")

    return avg_loss, accuracy


In [2]:
from yacs.config import CfgNode as CN
from base_model import BaseClassifier
from data import make_dataset

source_train_loader, target_train_loader, source_test_loader, target_test_loader = (
    make_dataset(
        source_dataset="mnist",
        target_dataset="usps",
        img_size=64,
        train_bs=256,
        eval_bs=256,
        num_workers=8,
    )
)
cfg = CN(new_allowed=True)
cfg.merge_from_file("configs/test.yaml")
model = BaseClassifier(
    backbone=cfg.model.backbone.type,
    in_dim=cfg.model.backbone.in_dim,
    hidden_dim=cfg.model.backbone.hidden_dim,
    out_dim=cfg.dataset.num_classes,
    num_res_blocks=cfg.model.backbone.num_res_blocks,
    imgsize=cfg.img_size,
    patch_size=cfg.model.patch_size,
    attribute_layers=cfg.model.attribute_layers,
    p_vr_src=cfg.model.source.vr_dropout,
    p_vr_tgt=cfg.model.target.vr_dropout,
    p_cls_src=cfg.model.source.cls_dropout,
    p_cls_tgt=cfg.model.target.cls_dropout,
)



In [4]:
device = torch.device("cuda:1")

ckpt = torch.load("runs/mnist_usps/9N92DG/da_best_96.81.pth")
model.load_state_dict(ckpt["model_state_dict"])
model.to(device)

avg_loss, accuracy = evaluate_mc_track_indices(
    model,
    branch="src",
    test_loader=source_test_loader,
    device=device,
    tau=0.8,
    mc_samples=16,
    num_classes=10
)


Per-class highest-uncertainty misclassifications:
 Class  0 → idx  4824, pred= 5, uncertainty=1.5758
 Class  1 → idx  6783, pred= 6, uncertainty=1.6255
 Class  2 → idx  1514, pred= 0, uncertainty=1.7356
 Class  3 → idx    18, pred= 9, uncertainty=1.5018
 Class  4 → idx   736, pred= 1, uncertainty=1.5186
 Class  5 → idx  3558, pred= 0, uncertainty=1.5661
 Class  6 → idx  3762, pred= 8, uncertainty=1.6977
 Class  7 → idx  7260, pred= 0, uncertainty=1.6738
 Class  8 → idx  3062, pred= 7, uncertainty=1.8552
 Class  9 → idx  4224, pred= 7, uncertainty=1.6810


In [5]:
from torch.utils.data import DataLoader, Subset
from data import StrongWeakAugDataset

source_dataset = "mnist"
target_dataset = "usps"
img_size=64
source_test_data = StrongWeakAugDataset(
        dataset_name=source_dataset, root="./data", img_size=img_size, train=False
    )
target_test_data = StrongWeakAugDataset(
    dataset_name=target_dataset, root="./data", img_size=img_size, train=False
)

vis_src_dataset = Subset(source_test_data, [4824,6783,1514,18,736,3558,3762,7260,3062,4224])
vis_src_loader = DataLoader(
    vis_src_dataset,
    batch_size=10, 
    shuffle=False,     
    num_workers=4, 
)

vis_tgt_dataset = Subset(target_test_data, [1096,1019, 1316,122,265,793,698,1118,198,1095])
vis_tgt_loader = DataLoader(
    vis_tgt_dataset,
    batch_size=10, 
    shuffle=False,     
    num_workers=4, 
)

In [7]:
device = torch.device("cuda:1")

import os

all_mnist_exp = os.listdir("runs/mnist_usps")
for exp in all_mnist_exp:
    all_ckpt = os.listdir(os.path.join("runs/mnist_usps", exp))
    for ckpt in all_ckpt:
        if ckpt.endswith(".pth"):
            print("Evaluating {} {}".format(exp, ckpt))
            model.load_state_dict(
                torch.load(os.path.join("runs/mnist_usps", exp, ckpt), map_location=device)['model_state_dict']
            )
            model.to(device)

            avg_loss, accuracy = evaluate_mc(
                model,
                branch="tgt",
                test_loader=target_test_loader,
                device=device,
                tau=0.8,
                mc_samples=16,
            )
            print("Average Loss: {:.4f}, Accuracy: {:.2f}%".format(avg_loss, accuracy))
            print("="*20)

Evaluating 0IYP3g da_best_95.86.pth
Average Loss: 1.5197, Accuracy: 95.76%
Evaluating 0IYP3g da_best_95.52.pth
Average Loss: 1.5349, Accuracy: 95.17%
Evaluating 0IYP3g da_best_96.41.pth
Average Loss: 1.5090, Accuracy: 96.36%
Evaluating 0IYP3g da_best_96.61.pth
Average Loss: 1.5033, Accuracy: 96.41%
Evaluating 0IYP3g da_best_96.11.pth
Average Loss: 1.5112, Accuracy: 96.01%
Evaluating 0IYP3g da_best_96.06.pth


KeyboardInterrupt: 