In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [23]:
from collections import Counter, defaultdict, deque

import numpy as np
import plotly.graph_objects as go

from imagenet.tree import get_imagenet_structure

In [13]:
# Archive with the self-supervised features; later cells read its
# "target" and "idx_to_class" arrays.
FEATURES_ARCHIVE = "self_supervised_features.npz"
# Archive with the nearest-neighbour results; later cells read "target_arr".
KNN_ARCHIVE = "knn.npz"

data = np.load(FEATURES_ARCHIVE)
knn_data = np.load(KNN_ARCHIVE)

In [29]:
# Load the class hierarchy from the released structure file.
tree = get_imagenet_structure("structure_released.xml")

# Breadth-first walk from the root, recording each node's name under its wnid.
# A deque gives O(1) pops from the front; list.pop(0) shifts every remaining
# element and makes the walk quadratic in the number of nodes.
wnid_to_name = {}
to_expand = deque([tree.root])
while to_expand:
    node = to_expand.popleft()
    # If a wnid occurs more than once, the node visited last wins (same as a
    # plain BFS that overwrites on every visit).
    wnid_to_name[node.wnid] = node.name
    to_expand.extend(node.children)

# Lookup used below as wnid_to_name[idx_to_class[target]]: indexed by the
# values stored in data["target"], yielding keys of wnid_to_name.
idx_to_class = data["idx_to_class"]

In [50]:
# Experiment grid: how many classes to draw, and the neighbourhood sizes to test.
n_classes = [50, 100, 500]
K = [1, 5, 10, 50]

# Results keyed by "<n_classes>.<k>". The *_per_class_* dicts map that key to a
# {class name: value} dict.
mean_acc_dict = {}
acc_per_class_dict = {}
neighbor_with_same_label_per_class_dict = {}
target_in_neighbor_per_class_dict = {}

# Indexing an .npz archive reads the array from the file on every access, so
# fetch both arrays once instead of once per class subset.
all_targets = data["target"]
all_neighbor_targets = knn_data["target_arr"]

for n_c in n_classes:

    # Fixed seed, re-created for every subset size, so each draw is reproducible.
    npr = np.random.RandomState(123)
    class_idx = npr.choice(1000, n_c, replace=False)
    # Set for O(1) membership tests (`x in ndarray` is a linear scan).
    class_set = set(class_idx.tolist())

    # Samples whose ground-truth label belongs to the drawn classes.
    data_idx = np.flatnonzero(np.isin(all_targets, class_idx))
    gts = all_targets[data_idx]
    # For each kept sample, its neighbours' labels restricted to the drawn
    # classes, in their original order.
    nns = [[i for i in t if i in class_set] for t in all_neighbor_targets[data_idx]]

    print("{} classes".format(n_c))
    for k in K:
        result_key = "{}.{}".format(n_c, k)

        mean_acc = []
        neighbor_with_same_label = []
        target_in_neighbor = []

        acc_per_class = defaultdict(list)
        neighbor_with_same_label_per_class = defaultdict(list)
        target_in_neighbor_per_class = defaultdict(list)

        for nn, gt in zip(nns, gts):
            knn = Counter(nn[:k])
            # Majority vote over the first k remaining neighbours. When the
            # class filter removed every neighbour there is nothing to vote
            # on; score the sample as a miss instead of raising IndexError.
            pred = knn.most_common(1)[0][0] if knn else None
            correct = pred is not None and pred == gt
            # Number of voting neighbours that carry the ground-truth label.
            same_label = knn[gt]
            class_name = wnid_to_name[idx_to_class[gt]]

            mean_acc.append(correct)
            acc_per_class[class_name].append(correct)

            neighbor_with_same_label.append(same_label)
            neighbor_with_same_label_per_class[class_name].append(same_label)

            target_in_neighbor.append(same_label > 0)
            target_in_neighbor_per_class[class_name].append(same_label > 0)

        mean_acc = np.mean(mean_acc)
        mean_acc_dict[result_key] = mean_acc
        acc_per_class = {_k: np.mean(_v) for _k, _v in acc_per_class.items()}
        acc_per_class_dict[result_key] = acc_per_class

        # Dividing by k turns the mean count into a fraction of the neighbourhood.
        neighbor_with_same_label = np.mean(neighbor_with_same_label) / k
        neighbor_with_same_label_per_class = {_k: np.mean(_v) / k for _k, _v in neighbor_with_same_label_per_class.items()}
        neighbor_with_same_label_per_class_dict[result_key] = neighbor_with_same_label_per_class

        target_in_neighbor = np.mean(target_in_neighbor)
        target_in_neighbor_per_class = {_k: np.mean(_v) for _k, _v in target_in_neighbor_per_class.items()}
        target_in_neighbor_per_class_dict[result_key] = target_in_neighbor_per_class

        print("{}-NN\tmean accuracy: {:.2f}\t{:.2f}% semantically-similar neighbor\t{:.2f}% of times found target in neighbor".format(k, mean_acc, neighbor_with_same_label*100, target_in_neighbor*100))

50 classes
1-NN	mean accuracy: 0.81	81.40% semantically-similar neighbor	81.40% of times found target in neighbor
5-NN	mean accuracy: 0.83	74.35% semantically-similar neighbor	92.96% of times found target in neighbor
10-NN	mean accuracy: 0.84	68.42% semantically-similar neighbor	95.68% of times found target in neighbor
50-NN	mean accuracy: 0.83	36.34% semantically-similar neighbor	97.88% of times found target in neighbor
100 classes
1-NN	mean accuracy: 0.75	74.58% semantically-similar neighbor	74.58% of times found target in neighbor
5-NN	mean accuracy: 0.77	66.86% semantically-similar neighbor	88.32% of times found target in neighbor
10-NN	mean accuracy: 0.78	60.96% semantically-similar neighbor	92.26% of times found target in neighbor
50-NN	mean accuracy: 0.76	33.63% semantically-similar neighbor	97.88% of times found target in neighbor
500 classes
1-NN	mean accuracy: 0.59	59.15% semantically-similar neighbor	59.15% of times found target in neighbor
5-NN	mean accuracy: 0.62	50.66% se

In [51]:
def plot(dict, keys, title, n_sample=20):
    """Show a grouped bar chart of per-class values for a random sample of classes.

    Parameters
    ----------
    dict : mapping
        Maps every entry of ``keys`` to a ``{class name: value}`` mapping.
        (The name shadows the built-in; it is kept so existing calls work.)
    keys : sequence
        Result keys to draw, one bar series per key. The classes are taken
        from ``dict[keys[0]]`` and must be present under every other key.
    title : str
        Figure title.
    n_sample : int, optional
        Number of classes to draw. Capped at the number of available classes.

    The classes are sampled with a fixed seed and ordered by decreasing value
    under ``keys[0]``. The figure is displayed; nothing is returned.
    """
    fig = go.Figure()

    npr = np.random.RandomState(123)
    x = sorted(dict[keys[0]].keys())
    # Sampling without replacement raises ValueError when asked for more
    # items than exist, so never request more classes than are available.
    x = npr.choice(x, min(n_sample, len(x)), replace=False)

    # Order the sampled classes by their value under the first key, highest first.
    y = [dict[keys[0]][i] for i in x]
    idx = np.argsort(y)[::-1]
    x = [x[i] for i in idx]

    for key in keys:
        y = [dict[key][i] for i in x]
        fig.add_trace(go.Bar(name=key, x=x, y=y))

    fig.update_layout(barmode='group', title_text=title)
    fig.update_yaxes(range=[0, 1])
    fig.show()

In [52]:
# One grouped bar chart per class-subset size, with one bar series per k.
for n_c in n_classes:
    series_keys = ["{}.{}".format(n_c, k) for k in K]
    chart_title = "{} classes\tAccuracy per class".format(n_c)
    plot(acc_per_class_dict, keys=series_keys, title=chart_title)

In [53]:
# One grouped bar chart per class-subset size, with one bar series per k,
# showing the fraction of the k neighbours that share the sample's label.
for n_c in n_classes:
    plot(neighbor_with_same_label_per_class_dict,
         keys=["{}.{}".format(n_c, k) for k in K],
         # Title previously read "...the same labe" (truncated word).
         title="{} classes\tNeighbor with the same label".format(n_c))

In [54]:
# One grouped bar chart per class-subset size, with one bar series per k,
# showing how often the sample's own label appears among its neighbours.
for n_c in n_classes:
    series_keys = ["{}.{}".format(n_c, k) for k in K]
    chart_title = "{} classes\tTarget in the neighbor".format(n_c)
    plot(target_in_neighbor_per_class_dict, keys=series_keys, title=chart_title)