In [1]:
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, models, transforms
import scipy
import os
import pandas as pd
from hshap.utils import Net

os.environ["CUDA_VISIBLE_DEVICES"]="8"

device = torch.device("cuda:0")

torch.manual_seed(0)
model = Net()
weight_path = "model2.pth"
model.load_state_dict(torch.load(weight_path, map_location=device)) 
model.to(device)
model.eval()

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

data_dir = "/export/gaon1/data/jteneggi/data/synthetic/LOR"
dataset = datasets.ImageFolder(os.path.join(data_dir, "images"), transform)
image_names = [os.path.basename(sample[0]) for sample in dataset.samples]
L = len(image_names)
print("Found %d test images" % L)
true_positives = {}
classes = np.arange(0, 9) + 1
for c in classes:
    true_positives[str(c)] = []
print(true_positives)
false_negatives = []
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0)
for batch_id, (images, labels) in enumerate(dataloader):
    images = images.to(device)
    labels = labels.to(device)
    outputs = model(images)
    _, preds = torch.max(outputs, 1)
    
    np_labels = labels.cpu().numpy()
    np_preds = preds.detach().cpu().numpy()
    
    for i, label in enumerate(np_labels):
        image_id = batch_id * 4 + i
        image_name = image_names[image_id]
        crosses_count = classes[label]
        image = os.path.join(data_dir, "images/{}/{}".format(crosses_count, image_name))
        if np_preds[i] == 1:
            true_positives[str(crosses_count)].append(image)
        if np_preds[i] == 0:
            false_negatives.append(image)

correct = 0
wrong = 0
for c in true_positives:
    print(c, len(true_positives[c]))
    correct += len(true_positives[c])
wrong += len(false_negatives)
accuracy = correct/(correct+wrong)
print("Test accuracy: %.2f" % accuracy)
np.savez("true_positives", true_positives)
print("Saved true positives")

Found 2700 test images
{'1': [], '2': [], '3': [], '4': [], '5': [], '6': [], '7': [], '8': [], '9': []}
1 298
2 300
3 300
4 300
5 300
6 300
7 300
8 300
9 300
Test accuracy: 1.00
Saved true positives


In [2]:
correct = 0
wrong = 0
for c in true_positives:
    print(c, len(true_positives[c]))
    correct += len(true_positives[c])
wrong += len(false_negatives)
print(correct/(correct+wrong), correct+wrong)

1 298
2 300
3 300
4 300
5 300
6 300
7 300
8 300
9 300
0.9992592592592593 2700
