In [None]:
import torch
import os
import numpy as np
from tqdm.notebook import tqdm
from sklearn import metrics
from matplotlib import pyplot as plt
import h5py

np.random.seed(32635)

In [None]:
# Parse the fixed-width label table into {int key: (name, int value)}.
#   columns  0-21 -> label name (single quotes removed)
#   columns 22-24 -> integer used as the dict key
#   columns 25+   -> integer stored next to the name
# NOTE(review): what the two integers mean is not visible here — confirm against labels.txt.
with open('../preprocessing/labels.txt') as f:
    labels = {
        int(line[22:25].strip()):
        (line[:22].strip().replace("'", ""),
         int(line[25:].strip()))
        for line in f
        if line.strip()  # a blank line (e.g. trailing newline) would otherwise crash int('')
    }

# Keys whose second tuple entry lies in [0, 255) are the classes kept for evaluation.
eval_labels = [key for key, (_, value) in labels.items() if 0 <= value < 255]

# Names of those classes, in the same order as `eval_labels`.
eval_label_names = [labels[key][0] for key in eval_labels]

In [None]:
eval_labels

In [None]:
eval_label_names

In [None]:
path = '/mnt/lwll/lwll-coral/hrant/cs_patches_256/predictions_knn/'

In [None]:
path_to_knn='/mnt/lwll/lwll-coral/hrant/cs_patches_256/predictions_knn/sup_vit_train_72_val_2000'

In [None]:
device = 'cpu'
model = 'sup_vit'
path_to_read = '/mnt/lwll/lwll-coral/hrant/cs_patches_256/'

# Load the saved label arrays for `model` and keep them as int64 tensors on `device`.
# The `dino_` prefix is kept because the rest of the notebook refers to these names.
dino_labels_train = torch.from_numpy(
    np.load(os.path.join(path_to_read, f'{model}_labels_train_72.npy'))
).to(device=device, dtype=torch.int64)
dino_labels_val = torch.from_numpy(
    np.load(os.path.join(path_to_read, f'{model}_labels_val_2000.npy'))
).to(device=device, dtype=torch.int64)

In [None]:
dino_labels_val.shape

In [None]:
# Location of the HDF5 file that the chunk loop below reads `good_values` / `good_indices` from.
# NOTE(review): there is no separator between `{model}` and `train` here, unlike the
# `{model}_labels_...` names used when loading the labels — confirm the file name on disk.
hf_path = os.path.join(path, f'{model}train_72_test_all_10x186_NN.h5')

In [None]:
# Opened read-only and kept open for the chunk loop further down; nothing in this notebook
# closes the handle, so close it manually (file.close()) once the loop has run.
file = h5py.File(hf_path, 'r')

# DINO

In [None]:
# The .npy file holds a pickled Python object; .item() unwraps it so it can be indexed by key below.
# NOTE(review): allow_pickle=True can execute arbitrary code stored in the file — only load trusted files.
dino_dict = np.load(os.path.join(path, 'dino_val38k_10x186_NN.npy'), allow_pickle=True).item()

In [None]:
good_indices = dino_dict['good_indices']
good_values = dino_dict['good_values']

In [None]:
def topk_class(train_labels, values, indices, k=3):
    """Predict one label per query by majority vote over its k best-scoring neighbours.

    Args:
        train_labels: 1-D tensor/array; ``train_labels[i]`` is the label of training item ``i``.
        values: 2-D tensor/array of neighbour scores, one row per query (larger is better).
        indices: 2-D tensor/array with the same shape as ``values``; ``indices[q, j]`` is the
            training index that ``values[q, j]`` belongs to.
        k: number of highest-scoring neighbours that take part in the vote.

    Returns:
        1-D tensor holding one predicted label per row of ``values`` (same dtype as
        ``train_labels``). When several labels are equally frequent, the smallest one wins.
    """
    # Accept NumPy arrays as well as tensors (np.load results are passed in by some cells).
    train_labels = torch.as_tensor(train_labels)
    values = torch.as_tensor(values)
    indices = torch.as_tensor(indices)

    # Column positions of the k largest scores in every row: shape (n_queries, k).
    topk_positions = values.topk(k=k, largest=True).indices
    # Column positions -> training indices -> training labels, all rows at once.
    neighbour_ids = indices.gather(1, topk_positions).long()
    neighbour_labels = train_labels[neighbour_ids]

    # One prediction per query row (not one per candidate neighbour).
    predictions = train_labels.new_empty(neighbour_labels.shape[0])
    for row, row_labels in enumerate(neighbour_labels):
        # `counts` is aligned with `unique_labels`, so the winner must be looked up there;
        # indexing the raw neighbour labels with that position returns an unrelated label.
        unique_labels, counts = torch.unique(row_labels, return_counts=True)
        predictions[row] = unique_labels[counts.argmax()]
    return predictions

In [None]:
k = 1
top_k_classes = []

# NOTE(review): h5py iterates group names in its own order (alphanumeric by default), so a
# chunk named "10" would come before "2" — confirm this order matches `dino_labels_val`.
for chunk in tqdm(file.keys()):
    # `[()]` reads the whole HDF5 dataset into a NumPy array before handing it to torch.
    chunk_good_values = torch.as_tensor(file[chunk]['good_values'][()])
    chunk_good_indices = torch.as_tensor(file[chunk]['good_indices'][()])

    # Column positions of the k largest scores per row: shape (rows, k).
    topk_indices = chunk_good_values.topk(k=k, largest=True).indices
    # Column positions -> training indices -> training labels in one vectorised step
    # (replaces the per-row Python loop and the torch.tensor(tensor) copy warning).
    real_indices = chunk_good_indices.gather(1, topk_indices).long()
    topk_labels = dino_labels_train[real_indices]

    top_k_classes.append(topk_labels)

# Single (n_rows, k) tensor with the labels of the k best neighbours of every row.
top_k_classes = torch.cat(top_k_classes)

In [None]:
# Map each integer chunk id to its file name inside `path_to_knn`.
# The id is the text after the last "_" of the file name, up to the first "." that follows.
name_to_chunk = {
    int(fname.rsplit('_', 1)[-1].split('.')[0]): fname
    for fname in os.listdir(path_to_knn)
}

In [None]:
name_to_chunk

In [None]:
k = 1
top_k_classes = []

# Visit the chunk files in ascending chunk-id order.
for _, knn_file in sorted(name_to_chunk.items()):
    # NOTE(review): allow_pickle=True can execute arbitrary code stored in the file —
    # only load trusted files.
    knn = np.load(os.path.join(path_to_knn, knn_file), allow_pickle=True).item()
    # as_tensor accepts both NumPy arrays and tensors without the copy-construct warning.
    chunk_good_values = torch.as_tensor(knn['good_values'])
    chunk_good_indices = torch.as_tensor(knn['good_indices'])

    # Column positions of the k largest scores per row: shape (rows, k).
    topk_indices = chunk_good_values.topk(k=k, largest=True).indices
    # Column positions -> training indices -> training labels in one vectorised step
    # (replaces the per-row Python loop).
    real_indices = chunk_good_indices.gather(1, topk_indices).long()
    topk_labels = dino_labels_train[real_indices]

    top_k_classes.append(topk_labels)

# Single (n_rows, k) tensor with the labels of the k best neighbours of every row.
top_k_classes = torch.cat(top_k_classes)


In [None]:
top_k_classes.shape

In [None]:
dino_acc_score_test = metrics.accuracy_score(dino_labels_val, top_k_classes)

In [None]:
dino_acc_score_test

In [None]:
dino_cm_test = metrics.confusion_matrix(dino_labels_val, top_k_classes, labels=eval_labels)

In [None]:
# Majority-vote predictions for several neighbourhood sizes, stored under 'y_pred_<k>'.
# topk_class requires `values` and `indices`; omitting them raised a TypeError. The neighbour
# arrays taken from `dino_dict` are passed, in the same way the MAE cell calls it.
dino_dct = {
    f'y_pred_{k}': topk_class(dino_labels_train, values=good_values, indices=good_indices, k=k)
    for k in [1, 3, 5, 10]
}

In [None]:
k = 5
dino_acc_score = metrics.accuracy_score(dino_labels_val, dino_dct[f'y_pred_{k}'])
dino_cm = metrics.confusion_matrix(dino_labels_val, dino_dct[f'y_pred_{k}'], labels=eval_labels)

In [None]:
# Confusion matrix of the chunked nearest-neighbour predictions held in `top_k_classes`.
# `top_k_classes` has shape (n_rows, k), so its last dimension is the neighbour count that was
# actually used. The global `k` is reassigned between the cell that builds `top_k_classes`
# and this one, which made the title report the wrong k.
neighbours_used = top_k_classes.shape[-1]
tick_positions = range(len(eval_labels))

plt.figure()
plt.imshow(dino_cm_test)
plt.yticks(tick_positions, eval_label_names)
plt.xticks(tick_positions, eval_label_names, rotation=90)
plt.title(f"SUP_VIT: {neighbours_used}-NN classification: {dino_acc_score_test*100:.1f}% accuracy\n( train: 72 images\nvalidation: 38K patches)")
plt.colorbar();

In [None]:
top1_value_indices = good_values.argmax(-1)
top1_indices = good_indices[np.arange(good_indices.shape[0]), top1_value_indices]
top1_indices.shape

In [None]:
y_pred_top1 = dino_labels_train[top1_indices]

In [None]:
# One-vs-rest accuracy for every evaluated class: how often "belongs to this class" agrees
# between the validation labels and the top-1 predictions.
dino_acc = [
    metrics.accuracy_score(dino_labels_val == cls, y_pred_top1 == cls)
    for cls in eval_labels
]

In [None]:
dino_acc_score = metrics.accuracy_score(dino_labels_val, y_pred_top1)

In [None]:
dino_cm = metrics.confusion_matrix(dino_labels_val, y_pred_top1, labels=eval_labels)

In [None]:
# Heat-map of `dino_cm` with the evaluated class names on both axes.
tick_positions = range(len(eval_labels))

plt.figure()
plt.imshow(dino_cm)
plt.xticks(tick_positions, eval_label_names, rotation=90)
plt.yticks(tick_positions, eval_label_names)
plt.title(f"DINO: 1-NN classification: {dino_acc_score*100:.1f}% accuracy\n(validation: 38K patches)")
plt.colorbar();

### MAE

In [None]:
# The .npy file holds a pickled Python object; .item() unwraps it so it can be indexed by key below.
# NOTE(review): allow_pickle=True can execute arbitrary code stored in the file — only load trusted files.
mae_dict = np.load(os.path.join(path, 'mae_val38k_10x186_NN.npy'), allow_pickle=True).item()

In [None]:
good_indices = mae_dict['good_indices']
good_values = mae_dict['good_values']

In [None]:
mae_dct = {f'y_pred_{k}': topk_class(dino_labels_train, values=good_values, indices=good_indices, k=k) for k in [1, 3, 5, 10]}

In [None]:
k = 10
mae_acc_score = metrics.accuracy_score(dino_labels_val, mae_dct[f'y_pred_{k}'])
mae_cm = metrics.confusion_matrix(dino_labels_val, mae_dct[f'y_pred_{k}'], labels=eval_labels)

In [None]:
# Heat-map of `mae_cm` (built from the k-NN vote with the current `k`), class names on both axes.
tick_positions = range(len(eval_labels))

plt.figure()
plt.imshow(mae_cm)
plt.xticks(tick_positions, eval_label_names, rotation=90)
plt.yticks(tick_positions, eval_label_names)
plt.title(f"MAE: {k}-NN classification: {mae_acc_score*100:.1f}% accuracy\n(validation: 38K patches)")
plt.colorbar();

In [None]:
top1_value_indices = good_values.argmax(-1)

In [None]:
top1_indices = good_indices[np.arange(good_indices.shape[0]), top1_value_indices]

In [None]:
y_pred_top1 = dino_labels_train[top1_indices]

In [None]:
# One-vs-rest accuracy for every evaluated class: how often "belongs to this class" agrees
# between the validation labels and the top-1 predictions.
mae_acc = [
    metrics.accuracy_score(dino_labels_val == cls, y_pred_top1 == cls)
    for cls in eval_labels
]

In [None]:
mae_acc_score = metrics.accuracy_score(dino_labels_val, y_pred_top1)

In [None]:
mae_cm = metrics.confusion_matrix(dino_labels_val, y_pred_top1, labels=eval_labels)

In [None]:
# Heat-map of the top-1 `mae_cm` with the evaluated class names on both axes.
tick_positions = range(len(eval_labels))

plt.figure()
plt.imshow(mae_cm)
plt.xticks(tick_positions, eval_label_names, rotation=90)
plt.yticks(tick_positions, eval_label_names)
plt.title(f"MAE: 1-NN classification: {mae_acc_score*100:.1f}% accuracy\n(validation: 38K patches)")
plt.colorbar();

### Both

In [None]:
k = 5

# One-vs-rest accuracy per evaluated class, computed from each model's k-NN predictions.
dino_acc = [
    metrics.accuracy_score(dino_labels_val == cls, dino_dct[f'y_pred_{k}'] == cls)
    for cls in eval_labels
]
mae_acc = [
    metrics.accuracy_score(dino_labels_val == cls, mae_dct[f'y_pred_{k}'] == cls)
    for cls in eval_labels
]

# Semi-transparent horizontal bars so both models stay visible where they overlap.
plt.figure()
plt.barh(eval_label_names, dino_acc, label='DINO', alpha=0.5);
plt.barh(eval_label_names, mae_acc, label='MAE', alpha=0.5);
plt.title(f"Per-class accuracy using {k}-NN classifier")
plt.legend()
plt.xlim(0.6,1.05)

In [None]:
# Per-class one-vs-rest accuracy of the DINO predictions at k=10 and at k=1.
k = 10
dino_10 = [
    metrics.accuracy_score(dino_labels_val == cls, dino_dct[f'y_pred_{k}'] == cls)
    for cls in eval_labels
]

k = 1
dino_1 = [
    metrics.accuracy_score(dino_labels_val == cls, dino_dct[f'y_pred_{k}'] == cls)
    for cls in eval_labels
]

plt.figure()
plt.barh(eval_label_names, dino_1, label='DINO k=1', alpha=0.5);
plt.barh(eval_label_names, dino_10, label='DINO k=10', alpha=0.5);
# Title fixed: it previously read "using for DINO".
plt.title("Per-class accuracy for DINO")
plt.legend(bbox_to_anchor=(1, 0.5));

In [None]:
# Per-class one-vs-rest accuracy of the MAE predictions at k=10 and at k=1.
k = 10
mae_10 = [
    metrics.accuracy_score(dino_labels_val == cls, mae_dct[f'y_pred_{k}'] == cls)
    for cls in eval_labels
]

k = 1
mae_1 = [
    metrics.accuracy_score(dino_labels_val == cls, mae_dct[f'y_pred_{k}'] == cls)
    for cls in eval_labels
]

plt.figure()
plt.barh(eval_label_names, mae_1, label='MAE k=1', alpha=0.5);
plt.barh(eval_label_names, mae_10, label='MAE k=10', alpha=0.5);
# Title fixed: it was cut off as "using M".
plt.title("Per-class accuracy for MAE")
plt.legend(bbox_to_anchor=(1, 0.5));

In [None]:
# Overlay of the per-class accuracies currently held in `dino_acc` and `mae_acc`.
# NOTE(review): the last cell above that assigns `dino_acc` / `mae_acc` uses k = 5, so when the
# notebook runs top to bottom the "1-NN" in this title does not describe the plotted data — confirm.
plt.figure()
plt.barh(eval_label_names, dino_acc, label='DINO', alpha=0.5);
plt.barh(eval_label_names, mae_acc, label='MAE', alpha=0.5);
plt.title("Per-class accuracy using 1-NN classifier")
plt.legend()
plt.xlim(0.6,1.05)

In [None]:
np.array(mae_acc) - np.array(dino_acc)

In [None]:
# Signed per-class accuracy gap: a positive bar means `mae_acc` is higher than `dino_acc`
# for that class. The title now says the bars are a difference, not an accuracy.
plt.figure()
plt.barh(eval_label_names, np.array(mae_acc) - np.array(dino_acc));
plt.title("Per-class accuracy difference (MAE - DINO)")

In [None]:
from PIL import Image
labelids = Image.open('/mnt/lwll/lwll-coral/hrant/gtFine/test/bonn/bonn_000045_000019_gtFine_labelIds.png')

In [None]:
np.array(labelids)