In [None]:
import os
import math
import torch
from torch import nn, optim
from torchvision import transforms
from torch.utils.data import DataLoader, random_split, Dataset
from PIL import Image
from torchvision import models

In [None]:
# Root of the local data folder, relative to this notebook.
DATA_DIR = os.path.join('..', 'data')
# Metadata CSV that XRayDataset reads (download instructions are in README.md).
CSV_FILE = os.path.join(DATA_DIR, 'Data_Entry_2017.csv')
# Number of images per DataLoader batch.
BATCH_SIZE = 10

In [None]:
import csv

class XRayDataset(Dataset):
    """Single-label X-ray images listed in the metadata CSV.

    Only CSV rows with exactly one label, where that label is not
    'No Finding', are kept. Each item is ``(transform(image), class_index)``,
    where ``class_index`` is the label's position in the sorted
    ``self.classes`` list, returned as a ``torch.long`` scalar tensor.

    Args:
        transform: callable applied to each opened ``PIL.Image``.
        validation: currently unused; kept for backward compatibility.
        csv_file: path to the metadata CSV. Defaults to ``CSV_FILE``.
        image_dir: directory containing the image files named in the CSV's
            first column. Defaults to ``../data/processed_images``.

    Raises:
        FileNotFoundError: if the CSV file does not exist.
    """

    def __init__(self, transform, validation=False, csv_file=None, image_dir=None):
        self.transform = transform
        # NOTE(review): `validation` is never read, so there is no train/val
        # split here -- every qualifying CSV row is included.
        self.csv_file = CSV_FILE if csv_file is None else csv_file
        self.image_dir = (os.path.join('..', 'data', 'processed_images')
                          if image_dir is None else image_dir)
        if not os.path.exists(self.csv_file):
            # FileNotFoundError is still an Exception, so existing handlers keep working.
            raise FileNotFoundError('missing csv data file {}, please download data as described in README.md'.format(self.csv_file))

        self.files = []
        found_classes = set()
        # newline='' is what the csv module documentation requires for csv.reader input.
        with open(self.csv_file, newline='') as csvfile:
            reader = csv.reader(csvfile, delimiter=',')
            next(reader)  # skip header row
            for row in reader:
                # first column: image filename, second column: '|'-separated labels
                filename, labels, *_ = row
                labels = labels.split('|')
                # keep only single-label rows, and drop the 'No Finding' class
                if len(labels) != 1 or labels[0] == 'No Finding':
                    continue
                self.files.append((filename, labels[0]))
                found_classes.update(labels)

        # Sets have no defined iteration order; sorting gives a stable,
        # reproducible class -> index mapping across runs.
        self.classes = sorted(found_classes)
        # Constant-time label lookup for __getitem__ (instead of list.index per item).
        self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}

    def __getitem__(self, index):
        """Return ``(transform(image), label_index)`` for the ``index``-th kept row."""
        filename, label = self.files[index]
        image = Image.open(os.path.join(self.image_dir, filename))
        image = self.transform(image)
        return image, torch.tensor(self.class_to_idx[label], dtype=torch.long)

    def __len__(self):
        """Number of kept (single-label, non-'No Finding') rows."""
        return len(self.files)

In [None]:
# Evaluation data: every single-label, non-'No Finding' image from the CSV.
# NOTE(review): XRayDataset ignores its `validation` flag, so this covers all
# qualifying images -- including any the checkpoint may have been trained on.
# Confirm before reading the numbers below as held-out accuracy.
# NOTE(review): ToTensor keeps the image's channel count; densenet121's first
# conv expects 3 channels, so this assumes processed_images are RGB -- TODO confirm.
dataset = XRayDataset(transform=transforms.ToTensor())
# DataLoader does not shuffle by default, so batches follow CSV order.
loader = DataLoader(dataset, batch_size=BATCH_SIZE)

num_classes = len(dataset.classes)

# Rebuild the architecture the checkpoint was saved from: DenseNet-121 with its
# classifier replaced by Linear(num_ftrs -> num_classes) followed by Sigmoid.
# load_state_dict is strict by default, so this layout must match the saved keys.
model = models.densenet121()
num_ftrs = model.classifier.in_features
model.classifier = nn.Sequential(nn.Linear(num_ftrs, num_classes), nn.Sigmoid())

CHECKPOINT_PATH = '../densenet_single_full_nonofindings_3_0.29417.pth'
# map_location='cpu': the model is never moved to a GPU in this notebook, and
# this lets a checkpoint saved from a CUDA model load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (newer PyTorch versions accept weights_only=True for a restricted load).
state_dict = torch.load(CHECKPOINT_PATH, map_location='cpu')
model.load_state_dict(state_dict)

In [None]:
model.eval()  # evaluation mode (affects layers such as BatchNorm/Dropout)

# Per-class counters, indexed by the class index produced by XRayDataset
# (the label's position in dataset.classes).
total_per_class = [0] * num_classes
correct_per_class = [0] * num_classes

# Inference only: no_grad() stops autograd from recording a graph for every
# batch, which would otherwise waste memory and time.
with torch.no_grad():
    for batch_idx, (inputs, labels) in enumerate(loader, start=1):
        if batch_idx % 10 == 0:
            # progress: running counts over the batches processed so far
            print(total_per_class, correct_per_class)
        outputs = model(inputs)
        # predicted class = index of the highest score in each row
        preds = outputs.argmax(dim=1)

        # .tolist() gives plain ints, so list indexing and comparison are
        # done in Python rather than on 0-dim tensors.
        for label, pred in zip(labels.tolist(), preds.tolist()):
            total_per_class[label] += 1
            if label == pred:
                correct_per_class[label] += 1

# Final per-class accuracy; None for classes with no samples.
per_class_accuracy = {
    name: (correct / total if total else None)
    for name, total, correct in zip(dataset.classes, total_per_class, correct_per_class)
}
per_class_accuracy
