In [1]:
import torch
import numpy as np
from scipy.optimize import linear_sum_assignment as linear_assignment

In [2]:
fine_preds = torch.load('./preds_result/cifar100_5.pt')['preds']
targets = torch.load('./preds_result/cifar100_5.pt')['targets']
print(fine_preds)
print(targets)

[49 33 55 ... 51 42 70]
[49. 33. 72. ... 51. 42. 70.]


In [3]:
train_classes = range(80)
cls_mask = np.array([True if x in range(len(train_classes))
                                         else False for x in targets])
cls_mask

array([ True,  True,  True, ...,  True,  True,  True])

In [4]:
def get_cifar100_coarse_labels(fine_labels):
    coarse_labels = np.array([ 4,  1, 14,  8,  0,  6,  7,  7, 18,  3,
                                    3, 14,  9, 18,  7, 11,  3,  9,  7, 11,
                                    6, 11,  5, 10,  7,  6, 13, 15,  3, 15, 
                                    0, 11,  1, 10, 12, 14, 16,  9, 11,  5,
                                    5, 19,  8,  8, 15, 13, 14, 17, 18, 10,
                                    16, 4, 17,  4,  2,  0, 17,  4, 18, 17,
                                    10, 3,  2, 12, 12, 16, 12,  1,  9, 19, 
                                    2, 10,  0,  1, 16, 12,  9, 13, 15, 13,
                                    16, 19,  2,  4,  6, 19,  5,  5,  8, 19,
                                    18,  1,  2, 15,  6,  0, 17,  8, 14, 13])
    return coarse_labels[fine_labels]

def split_cluster_acc_v2(y_true, y_pred, mask):
    """
    Calculate clustering accuracy. Require scikit-learn installed
    First compute linear assignment on all data, then look at how good the accuracy is on subsets

    # Arguments
        mask: Which instances come from old classes (True) and which ones come from new classes (False)
        y: true labels, numpy.array with shape `(n_samples,)`
        y_pred: predicted labels, numpy.array with shape `(n_samples,)`

    # Return
        accuracy, in [0,1]
    """
    y_true = y_true.astype(int)

    old_classes_gt = set(y_true[mask])
    new_classes_gt = set(y_true[~mask])

    assert y_pred.size == y_true.size
    D = max(y_pred.max(), y_true.max()) + 1
    w = np.zeros((D, D), dtype=int)
    for i in range(y_pred.size):
        w[y_pred[i], y_true[i]] += 1

    ind = linear_assignment(w.max() - w)
    ind = np.vstack(ind).T

    ind_map = {j: i for i, j in ind}
    total_acc = sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size

    old_acc = 0
    total_old_instances = 0
    for i in old_classes_gt:
        old_acc += w[ind_map[i], i]
        total_old_instances += sum(w[:, i])
    old_acc /= total_old_instances

    new_acc = 0
    total_new_instances = 0
    for i in new_classes_gt:
        new_acc += w[ind_map[i], i]
        total_new_instances += sum(w[:, i])
    new_acc /= total_new_instances

    return total_acc, old_acc, new_acc, ind

In [5]:
total_acc, old_acc, new_acc, ind = split_cluster_acc_v2(targets, fine_preds, cls_mask)
print(total_acc, old_acc, new_acc)
print(ind)

0.8231 0.854 0.6995
[[ 0  0]
 [ 1  1]
 [ 2  2]
 [ 3  3]
 [ 4  4]
 [ 5  5]
 [ 6  6]
 [ 7  7]
 [ 8  8]
 [ 9  9]
 [10 10]
 [11 11]
 [12 12]
 [13 13]
 [14 14]
 [15 15]
 [16 16]
 [17 17]
 [18 18]
 [19 19]
 [20 20]
 [21 21]
 [22 22]
 [23 23]
 [24 24]
 [25 25]
 [26 26]
 [27 27]
 [28 28]
 [29 29]
 [30 95]
 [31 31]
 [32 32]
 [33 33]
 [34 34]
 [35 35]
 [36 36]
 [37 37]
 [38 38]
 [39 39]
 [40 40]
 [41 41]
 [42 42]
 [43 43]
 [44 44]
 [45 45]
 [46 46]
 [47 47]
 [48 48]
 [49 49]
 [50 50]
 [51 51]
 [52 52]
 [53 53]
 [54 54]
 [55 55]
 [56 56]
 [57 57]
 [58 58]
 [59 59]
 [60 60]
 [61 61]
 [62 62]
 [63 63]
 [64 64]
 [65 65]
 [66 66]
 [67 67]
 [68 68]
 [69 69]
 [70 70]
 [71 71]
 [72 72]
 [73 73]
 [74 74]
 [75 75]
 [76 76]
 [77 77]
 [78 78]
 [79 79]
 [80 89]
 [81 85]
 [82 82]
 [83 92]
 [84 30]
 [85 93]
 [86 88]
 [87 87]
 [88 90]
 [89 86]
 [90 97]
 [91 80]
 [92 96]
 [93 98]
 [94 81]
 [95 91]
 [96 94]
 [97 83]
 [98 99]
 [99 84]]


In [6]:
ind_target2coarse_map = {i:get_cifar100_coarse_labels(j) for i, j in ind}
target2coarse_preds = np.vectorize(ind_target2coarse_map.get)(fine_preds)
coarse_targets = get_cifar100_coarse_labels(targets.astype(np.int64))
acc = (target2coarse_preds == coarse_targets).mean()
acc

0.9091

In [7]:
target2coarse_preds

array([10, 10,  0, ...,  4,  8,  2])