In [1]:
import os
import json
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T

In [2]:
def _load_coco(path):
    with open(path, "r") as f:
        return json.load(f)

In [3]:
def _build_cat_mapping(train_coco):
    cat_ids = sorted({c["id"] for c in train_coco["categories"]})
    catid_to_idx = {cid: i for i, cid in enumerate(cat_ids)}
    return catid_to_idx

In [4]:
class CocoObjectCropClassification(Dataset):
    def __init__(self, images_dir, coco_json_path, catid_to_idx, transform=None):
        self.images_dir = images_dir
        self.transform = transform
        self.catid_to_idx = catid_to_idx

        coco = _load_coco(coco_json_path)

        self.img_by_id = {img["id"]: img for img in coco["images"]}

        # Build samples from annotations
        self.samples = []
        bad_category = 0

        for ann in coco["annotations"]:
            cid = ann["category_id"]
            if cid not in self.catid_to_idx:
                bad_category += 1
                continue

            img_info = self.img_by_id.get(ann["image_id"])
            if img_info is None:
                continue

            self.samples.append({
                "file_name": img_info["file_name"],
                "bbox": ann["bbox"],  # [x,y,w,h]
                "label": self.catid_to_idx[cid],
            })

        if bad_category > 0:
            print(f"[WARN] Skipped {bad_category} annotations with category_id not in train mapping.")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        s = self.samples[idx]
        img_path = os.path.join(self.images_dir, s["file_name"])

        image = Image.open(img_path).convert("RGB")

        x, y, w, h = s["bbox"]
        left   = int(round(x))
        top    = int(round(y))
        right  = int(round(x + w))
        bottom = int(round(y + h))

        # Clamp to image bounds (prevents crashes on weird bboxes)
        W, H = image.size
        left = max(0, min(left, W - 1))
        top = max(0, min(top, H - 1))
        right = max(left + 1, min(right, W))
        bottom = max(top + 1, min(bottom, H))

        crop = image.crop((left, top, right, bottom))
        label = s["label"]

        if self.transform:
            crop = self.transform(crop)

        return crop, label

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

data_root = "./coco"
test_dir = os.path.join(data_root, "test")

test_json = os.path.join(test_dir, "_annotations.coco.json")

image_size = 224
test_tf = T.Compose([
    T.Resize((image_size, image_size)),
    T.ToTensor(),
])
test_coco = _load_coco(test_json)
test_catid_to_idx = _build_cat_mapping(test_coco)

test_ds = CocoObjectCropClassification(test_dir, test_json, catid_to_idx=test_catid_to_idx, transform=test_tf)
test_loader = DataLoader(test_ds, batch_size=64, shuffle=False, num_workers=4)

print("Test samples:", len(test_ds))


Test samples: 424


In [6]:
import torch
import torch.nn as nn
from simclr import SimCLR
from simclr.modules import get_resnet

# ---- set these to what you trained with ----
resnet_name = "resnet50"
projection_dim = 64

class SimCLRFinetuneClassifier(nn.Module):
    def __init__(self, simclr_model, num_classes):
        super().__init__()
        self.simclr = simclr_model
        self.classifier = nn.Linear(self.simclr.n_features, num_classes)

    def forward(self, x):
        h, _, _, _ = self.simclr(x, x)
        return self.classifier(h)

ckpt_path = "/data2/gio/bobyard/finetuned_model/simclr_finetuned_epoch_19.pt"
checkpoint = torch.load(ckpt_path, map_location=device)

encoder = get_resnet(resnet_name, pretrained=False)
n_features = encoder.fc.in_features
simclr_model = SimCLR(encoder, projection_dim, n_features)

model = SimCLRFinetuneClassifier(simclr_model, checkpoint["num_classes"])
model.load_state_dict(checkpoint["model_state_dict"])
model = model.to(device)
model.eval()

catid_to_idx = checkpoint["catid_to_idx"]
print(f"Number of classes: {len(catid_to_idx)}")



Number of classes: 42


In [7]:
import torch

def evaluate(loader, model):
    correct1 = 0
    correct5 = 0
    total = 0

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)

            # top-1
            pred1 = logits.argmax(dim=1)
            correct1 += (pred1 == y).sum().item()

            # top-5 (safe if num_classes < 5)
            k = min(5, logits.size(1))
            topk = logits.topk(k, dim=1).indices
            correct5 += (topk == y.unsqueeze(1)).any(dim=1).sum().item()

            total += y.size(0)

            all_preds.append(pred1.cpu())
            all_labels.append(y.cpu())

    acc1 = correct1 / total
    acc5 = correct5 / total
    all_preds = torch.cat(all_preds).numpy()
    all_labels = torch.cat(all_labels).numpy()
    return acc1, acc5, all_labels, all_preds

acc1, acc5, y_true, y_pred = evaluate(test_loader, model)
print(f"Test Top-1 Acc: {acc1:.4f}")
print(f"Test Top-5 Acc: {acc5:.4f}")


Test Top-1 Acc: 0.9835
Test Top-5 Acc: 0.9929


In [8]:
from sklearn.metrics import confusion_matrix, classification_report
import numpy as np

cm = confusion_matrix(y_true, y_pred)
print("Confusion matrix shape:", cm.shape)

print(classification_report(y_true, y_pred, digits=4))


Confusion matrix shape: (31, 31)
              precision    recall  f1-score   support

           1     1.0000    1.0000    1.0000       149
           2     1.0000    1.0000    1.0000        17
           3     1.0000    1.0000    1.0000         3
           4     1.0000    1.0000    1.0000        25
           5     0.8621    1.0000    0.9259        25
           6     1.0000    1.0000    1.0000         4
           8     1.0000    1.0000    1.0000         5
           9     1.0000    1.0000    1.0000         5
          10     0.0000    0.0000    0.0000         1
          11     0.0000    0.0000    0.0000         0
          12     1.0000    1.0000    1.0000         3
          15     1.0000    0.6667    0.8000         9
          16     1.0000    1.0000    1.0000        59
          17     1.0000    1.0000    1.0000        19
          18     1.0000    1.0000    1.0000         5
          20     1.0000    1.0000    1.0000         1
          22     1.0000    1.0000    1.0000     

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
