In [1]:
import numpy as np
import torch
import json
import os


In [2]:
exp_folder = "kfold_mixup_exp_01"
exp_path = os.path.join(os.getcwd(), exp_folder)

ids = []
is_fem = []
ys = []
logits = []
for folder in os.listdir(exp_path):
    with open(os.path.join(exp_path, folder, "meta.json")) as f:
        data = json.load(f)
        # ids.append(data["ids"])
        # is_fem.append(data["is_fem"])
        # ys.append(data["ys"])
        # logits.append(data["forward_logits"])
        # 
        ids += (data["ids"])
        is_fem += (data["is_fem"])
        ys += (data["ys"])
        logits += (data["forward_logits"])
        


In [3]:
ys_arr = np.asarray(ys).squeeze().astype(bool)
is_fem_arr = np.asarray(is_fem).squeeze().astype(bool)
logits_arr = np.asarray(logits).squeeze()

def sigmoid(x):
    return  1./(1. + np.exp(-x)) 

print(ys_arr.shape, logits_arr.shape)

(8458,) (8458,)


In [4]:
pred_prob_class0 = 1.0-sigmoid(logits_arr)
pred_prob_class1 = sigmoid(logits_arr)

self_conf_class0 = np.mean(pred_prob_class0[~ys_arr])
self_conf_class1 = np.mean(pred_prob_class1[ys_arr])

print(
    "self conf positiv (patient has disease) class:", self_conf_class1,
    "\nself conf negativ (patient does not) class:", self_conf_class0
)

self conf positiv (patient has disease) class: 0.3582618398562962 
self conf negativ (patient does not) class: 0.9216680489250791


In [5]:
# for every image whos given label is class i, is the predicted probability for the given 
# label being class j greater than the class j threshold. if so we add it to the counts
# matrix as the given label is class i but should have been class j.

print(np.mean(pred_prob_class1[ys_arr] > self_conf_class0))
print(np.mean(pred_prob_class0[~ys_arr] > self_conf_class1))


0.007281553398058253
0.9909271687795178


In [7]:
# for every image whos given label is class i, is the predicted probability for the given 
# label being class j greater than the class j threshold. if so we add it to the counts
# matrix as the given label is class i but should have been class j.

def find_C(pred_prob_class0, pred_prob_class1, ys_arr, ids, filter = None, s=0.2):
    if filter is not None:
        pred_prob_class0 = pred_prob_class0[filter]
        pred_prob_class1 = pred_prob_class1[filter]
        ys_arr = ys_arr[filter]
        
    C = np.zeros((2,2))
    outliers = []
    for pred0, pred1, label, id in zip(pred_prob_class0, pred_prob_class1, ys_arr, ids):
        if label == 1:
            if self_conf_class0*s > pred1:
                C[0,1] += 1
                outliers.append(id)
            else:
                C[1,1] += 1
              
        elif label == 0:
            if self_conf_class1*s > pred0:
                C[1,0] += 1
                outliers.append(id)
            else:
                C[0,0] += 1
             
    return C, outliers

print("combined\n", find_C(pred_prob_class0, pred_prob_class1, ys_arr, ids))
print("male\n", find_C(pred_prob_class0, pred_prob_class1, ys_arr, ids, filter=~is_fem_arr))
print("female\n", find_C(pred_prob_class0, pred_prob_class1, ys_arr, ids, filter=is_fem_arr))



combined
 (array([[8046.,  168.],
       [   0.,  244.]]), [3, 164, 4315, 80, 4253, 16, 4366, 71, 223, 4377, 4278, 171, 36, 91, 4290, 4230, 167, 55, 4390, 204, 4284, 4274, 176, 211, 67, 4273, 41, 4240, 157, 95, 98, 175, 4245, 4401, 4330, 114, 163, 199, 4405, 173, 181, 135, 136, 182, 4376, 88, 4393, 4238, 24, 4317, 162, 200, 105, 4370, 4355, 168, 63, 4280, 192, 187, 212, 2, 126, 99, 4334, 4367, 35, 215, 4322, 4395, 62, 148, 147, 21, 58, 156, 22, 230, 128, 4316, 4259, 29, 207, 37, 170, 46, 166, 4241, 4251, 100, 195, 56, 33, 185, 4321, 228, 178, 214, 4406, 227, 4359, 34, 107, 4308, 4270, 201, 133, 4297, 5, 4347, 4309, 84, 43, 26, 4398, 68, 12, 4256, 4279, 4323, 4380, 4319, 4264, 15, 79, 4277, 51, 81, 28, 4295, 4266, 102, 4328, 27, 154, 4231, 4294, 87, 4336, 61, 125, 90, 103, 13, 4344, 142, 4282, 174, 131, 117, 198, 4304, 110, 4285, 172, 39, 4354, 129, 108, 121, 4384, 124, 60, 123, 53, 179, 4265, 218])
male
 (array([[4051.,   60.],
       [   0.,  118.]]), [164, 7341, 8211, 5696, 6529, 511