In [None]:
import torch
import os
import torch.nn.functional as F
import numpy as np

In [None]:
def load_from_file(filename):
    """Load an object saved with ``torch.save`` and place its tensors on the CPU.

    Args:
        filename: path (or file-like object) handed straight to ``torch.load``.

    Returns:
        Whatever object was serialized; its tensor storages are remapped to the CPU.
    """
    # NOTE: torch.load unpickles its input, so only point this at trusted files.
    cpu_device = torch.device('cpu')
    return torch.load(filename, map_location=cpu_device)

In [None]:
# Raise the summarisation threshold so large tensors are printed in full.
torch.set_printoptions(threshold=10000000)

# Ground-truth label given to every out-of-distribution (target) sample.
OOD_LABEL = 404

##### DISCRIMINATIVE
src_logits = load_from_file('resources/tensors/DGCNN_CE_SR1/src_logits.pt')
tar1_logits = load_from_file('resources/tensors/DGCNN_CE_SR1/tar1_logits.pt')
tar2_logits = load_from_file('resources/tensors/DGCNN_CE_SR1/tar2_logits.pt')

# Maximum softmax probability (MSP): the top class probability is the
# confidence score and its index is the predicted class.
# Assumes logits are laid out as (num_samples, num_classes) — TODO confirm.
src_MSP_scores, src_MSP_pred = F.softmax(src_logits, dim=1).max(1)
tar1_MSP_scores, tar1_MSP_pred = F.softmax(tar1_logits, dim=1).max(1)
tar2_MSP_scores, tar2_MSP_pred = F.softmax(tar2_logits, dim=1).max(1)

##### DISTANCE BASED
# Presumably the distance to, and the index of, the nearest training sample
# for every evaluated sample — verify against the script that produced them.
src_dist = load_from_file('resources/tensors/DGCNN_CE_SR1/src_dist.pt')
src_ids = load_from_file('resources/tensors/DGCNN_CE_SR1/src_ids.pt')
tar1_dist = load_from_file('resources/tensors/DGCNN_CE_SR1/tar1_dist.pt')
tar1_ids = load_from_file('resources/tensors/DGCNN_CE_SR1/tar1_ids.pt')
tar2_dist = load_from_file('resources/tensors/DGCNN_CE_SR1/tar2_dist.pt')
tar2_ids = load_from_file('resources/tensors/DGCNN_CE_SR1/tar2_ids.pt')

src_feats = load_from_file('resources/feats/DGCNN_CE_SR1.pt')

# set OOD labels
# Built directly as 1-D int64 tensors: the previous numpy (1, N) array followed
# by squeeze() collapsed to a 0-dim tensor for a single-sample target set and
# inherited numpy's platform-dependent default integer type.
tar1_labels = torch.full((len(tar1_logits),), OOD_LABEL, dtype=torch.long)
tar2_labels = torch.full((len(tar2_logits),), OOD_LABEL, dtype=torch.long)
src_labels = load_from_file('resources/tensors/DGCNN_CE_SR1/src_labels.pt')

In [None]:
import matplotlib.pyplot as plt

def accuracy(samples, labels, truth, threshold, ood_label=404) -> float:
    """Fraction of samples whose thresholded prediction equals the ground truth.

    A sample keeps its predicted label only when its score is strictly greater
    than ``threshold``; otherwise it is rejected and predicted as ``ood_label``.

    Args:
        samples: iterable of per-sample confidence scores.
        labels: indexable of predicted labels, aligned with ``samples``.
        truth: indexable of ground-truth labels, aligned with ``samples``.
        threshold: scores at or below this value are rejected.
        ood_label: label assigned to rejected samples. Defaults to 404, the
            value that used to be hard-coded here.

    Returns:
        ``correct / total`` as a float, or ``0.0`` when ``samples`` is empty
        (this used to raise ``ZeroDivisionError``).
    """
    correct = 0  # number of predictions that match the ground truth
    total = 0    # number of samples seen
    for idx, score in enumerate(samples):
        prediction = labels[idx] if score > threshold else ood_label
        if prediction == truth[idx]:
            correct += 1
        total = idx + 1
    if total == 0:
        return 0.0
    return correct / total

def calculate_threshold(scores, preds, truth, thresholds=None) -> float:
    """Pick the rejection threshold that maximises ``accuracy`` and plot the sweep.

    Args:
        scores: per-sample confidence scores.
        preds: predicted labels aligned with ``scores``.
        truth: ground-truth labels aligned with ``scores``.
        thresholds: optional indexable 1-D sequence of candidate thresholds.
            Defaults to 100 evenly spaced values in [0, 1], which suits
            probability-like scores; pass an explicit grid for scores that are
            not confined to that range.

    Returns:
        The candidate threshold with the highest accuracy; on ties the first
        such candidate is returned.

    Side effects:
        Draws the accuracy-vs-threshold curve on the current matplotlib axes.
    """
    if thresholds is None:
        thresholds = np.linspace(0, 1, 100)

    accuracies = [accuracy(scores, preds, truth, t) for t in thresholds]

    plt.title('Threshold effect on model accuracy')
    plt.xlabel('Threshold')
    plt.ylabel('Accuracy')
    plt.plot(thresholds, accuracies)

    return thresholds[np.argmax(accuracies)]

In [None]:
# Pool the source split with both target splits so a single threshold is
# selected over the whole evaluation set.
truth = torch.hstack([src_labels, tar1_labels, tar2_labels])
labels_MSP = torch.hstack([src_MSP_pred, tar1_MSP_pred, tar2_MSP_pred])
samples_MSP = torch.hstack([src_MSP_scores, tar1_MSP_scores, tar2_MSP_scores])

# Sweep the MSP confidence threshold and report the most accurate one.
threshold_MSP = calculate_threshold(samples_MSP, labels_MSP, truth)
print("Best value for threshold: " + str(threshold_MSP))

In [None]:
# train truth
train_labels = src_feats["train_labels"]


def _nearest_neighbour_outputs(dist, ids, train_labels):
    """Return ``(dist, ids, scores, pred)`` for one split, squeezed and on the CPU.

    The score is the inverse distance (closer means more confident) and the
    prediction is the label of the training sample referenced by ``ids``.
    """
    flat_dist = dist.squeeze().cpu()
    flat_ids = ids.squeeze().cpu()  # index of nearest training sample
    # Scores and predictions are derived from the squeezed tensors. The earlier
    # code used the raw ones, so a singleton dimension broke the hstack below.
    scores = 1 / flat_dist
    pred = torch.from_numpy(train_labels[flat_ids])
    return flat_dist, flat_ids, scores, pred


src_L2_dist, src_L2_ids, src_L2_scores, src_L2_pred = _nearest_neighbour_outputs(
    src_dist, src_ids, train_labels)
tar1_L2_dist, tar1_L2_ids, tar1_L2_scores, tar1_L2_pred = _nearest_neighbour_outputs(
    tar1_dist, tar1_ids, train_labels)
tar2_L2_dist, tar2_L2_ids, tar2_L2_scores, tar2_L2_pred = _nearest_neighbour_outputs(
    tar2_dist, tar2_ids, train_labels)

ids_L2 = torch.hstack((src_L2_ids, tar1_L2_ids, tar2_L2_ids))
scores_L2 = torch.hstack((src_L2_scores, tar1_L2_scores, tar2_L2_scores))
labels_L2 = torch.hstack((src_L2_pred, tar1_L2_pred, tar2_L2_pred))
truth = torch.hstack((src_labels, tar1_labels, tar2_labels))

# NOTE(review): this call sweeps candidate thresholds over [0, 1], while an
# inverse distance is not bounded by 1 — confirm that range suits these scores.
l2_threshold = calculate_threshold(scores_L2, labels_L2, truth)
print("Best value for threshold: " + str(l2_threshold))

# misclassified samples
# Apply the same rejection rule as accuracy(): a sample whose score is not above
# the threshold is predicted as OOD (404). Comparing the raw nearest-neighbour
# labels against truth would flag every target sample, since their truth is 404.
rejected_L2 = torch.full_like(labels_L2, 404)
pred_L2 = torch.where(scores_L2 > float(l2_threshold), labels_L2, rejected_L2)
misclas_samples_ids = (~pred_L2.eq(truth)).nonzero(as_tuple=True)

print(misclas_samples_ids)

