In [1]:
import json
import math
import os
from collections import defaultdict

import torch
from PIL import Image
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import models
from torchvision import transforms

from analyze import XRayDataset, load_model
from bounding_box import load_bboxes

In [2]:
# Load the bounding-box annotations, then bind `bboxes` into this namespace.
# Cell 3 uses `bboxes` as a membership table keyed by image filename.
# NOTE(review): the import is deliberately placed *after* the call. Presumably
# load_bboxes() (re)binds the module-level `bounding_box.bboxes`, so importing
# it at the top of the notebook would capture the pre-load object. Keep this
# order. TODO confirm against bounding_box.py.
load_bboxes()
from bounding_box import bboxes

In [3]:
# Run the classifier over every image that has a bounding-box annotation and,
# for each true class, record the filename, whether the prediction was
# correct, and the per-image cross-entropy loss.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

dataset, loader, model = load_model(device, batch_size=1)
num_classes = len(dataset.classes)

# This cell is inference only: put dropout / batch-norm layers in eval mode.
# Harmless if load_model already did so.
model.eval()

# NOTE: with the default reduction='mean', a weighted CrossEntropyLoss on a
# single sample divides by that sample's own class weight, so the per-image
# losses below are effectively unweighted.
criterion = nn.CrossEntropyLoss(weight=dataset.class_weights.to(device))

# true class index -> list of {'image', 'correct', 'loss'} records
images_per_class = defaultdict(list)

with torch.no_grad():
    for inputs, labels, filenames in loader:
        # Skip batches with no annotated image before paying for a forward pass.
        if not any(name in bboxes for name in filenames):
            continue
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        for i, name in enumerate(filenames):
            # Filter per sample so batch_size > 1 still keeps only annotated images.
            if name not in bboxes:
                continue
            label_idx = labels[i].item()
            # Slice with i:i+1 to keep the batch dimension the loss expects.
            loss = criterion(outputs[i:i+1], labels[i:i+1])
            images_per_class[label_idx].append({
                'image': name,
                'correct': preds[i].item() == label_idx,
                'loss': loss.item(),
            })

In [4]:
# For each class, keep the TOP_K correctly classified annotated images with
# the lowest loss and write them to JSON, keyed by class name.
TOP_K = 10
best_ten = {}

# Iterate in label order so the JSON key order does not depend on loader order.
for label_idx, records in sorted(images_per_class.items()):
    correct_records = [rec for rec in records if rec['correct']]
    correct_records.sort(key=lambda rec: rec['loss'])
    best_ten[dataset.classes[label_idx]] = correct_records[:TOP_K]

# NOTE(review): the file name says "best3" while TOP_K is 10. Confirm the
# intended output name before relying on it downstream.
with open('best3_bbox.json', 'w') as f:
    json.dump(best_ten, f, indent=2)