In [1]:
import sys
sys.path.append('../')

import matplotlib.pyplot as plt

In [2]:
import torch
import torch.nn as nn
import numpy as np

In [3]:
# Seed NumPy's global RNG so the random scores generated below are reproducible.
np.random.seed(seed=0)

In [4]:
# Fake model output: uniform random scores for n_ex examples over n_classes
# classes, normalised row-wise into probabilities with a softmax over dim 1.
n_classes, n_ex = 3, 10
raw_scores = np.random.rand(n_ex, n_classes)
y_pred = torch.softmax(torch.tensor(raw_scores), dim=1)
print(y_pred)

tensor([[0.3090, 0.3649, 0.3261],
        [0.3342, 0.2961, 0.3697],
        [0.2343, 0.3691, 0.3966],
        [0.2732, 0.4109, 0.3159],
        [0.3291, 0.4706, 0.2002],
        [0.2474, 0.2313, 0.5213],
        [0.3014, 0.3304, 0.3683],
        [0.3711, 0.2647, 0.3642],
        [0.2695, 0.4541, 0.2764],
        [0.4457, 0.2920, 0.2623]], dtype=torch.float64)


In [5]:
def predict(x, model, use_cache=False, params=None, thres=None):
    """Assign a class index to every example from per-class scores.

    Args:
        x: Model input; only used when ``use_cache`` is False.
        model: Callable applied to ``x`` when ``use_cache`` is False. Its output
            is assumed to be a 2-D tensor of scores with classes on dim 1
            (softmax is taken over dim 1) -- TODO confirm for real models.
        use_cache: If True, skip the model and read scores from ``params``.
        params: ``(y_pred, softmax)`` pair used when ``use_cache`` is True.
            ``softmax`` tells whether ``y_pred`` already holds probabilities;
            when falsy, a softmax over dim 1 is applied first.
        thres: ``None`` or ``'auto'`` for a plain argmax. Otherwise a dict
            ``{class_index: probability_threshold}`` whose keys are listed in
            priority order: keys are applied last-to-first, so when several
            classes exceed their thresholds the earliest-listed key wins.

    Returns:
        Argmax mode: an int64 ``torch`` tensor with one class index per row.
        Threshold mode: a float64 numpy array with one class index per row;
        rows where no class exceeds its threshold are labelled 0.
    """
    if use_cache:
        y_pred, softmax = params
        if not softmax:
            y_pred = torch.softmax(y_pred, dim=1)
    else:
        y_pred = torch.softmax(model(x), dim=1)

    # Model output may require grad and/or live on a GPU; Tensor.numpy()
    # raises in both cases, so detach and move to CPU before converting.
    probs = y_pred.detach().cpu().numpy()

    if thres is None or thres == 'auto':
        # Torch argmax is slow (compared with 18 examples, torch => 23.5 secs, numpy => 1.6 secs)
        return torch.tensor(np.argmax(probs, axis=1))

    masks = np.zeros((probs.shape[0],))
    for cls_ in reversed(list(thres.keys())):
        # Later iterations overwrite earlier ones, so the earliest-listed
        # (highest-priority) classes are written last and win ties.
        masks[probs[:, cls_] > thres[cls_]] = cls_
    return masks


In [6]:
# No thresholds given -> plain argmax over the cached probabilities.
predict(x=None, model=None, params=(y_pred, True), use_cache=True, thres=None)

tensor([1, 2, 2, 1, 1, 2, 2, 0, 1, 0])

In [7]:
# 'auto' is handled exactly like None, so this matches the previous cell.
predict(x=None, model=None, params=(y_pred, True), use_cache=True, thres='auto')

tensor([1, 2, 2, 1, 1, 2, 2, 0, 1, 0])

In [8]:
# Class 0 has the highest priority: predict() applies the thresholds in
# reverse key order, so when several classes pass their thresholds the
# first-listed key wins. List the keys in priority order.
thresholds = "0:0.42,1:0.33,2:0.25"
thres = {
    int(parts[0]): float(parts[1])
    for parts in (item.split(':') for item in thresholds.split(','))
}

labels = predict(None, None, thres=thres, use_cache=True, params=(y_pred, True))
print(labels)
print(y_pred)

[1. 2. 1. 1. 1. 2. 1. 2. 1. 0.]
tensor([[0.3090, 0.3649, 0.3261],
        [0.3342, 0.2961, 0.3697],
        [0.2343, 0.3691, 0.3966],
        [0.2732, 0.4109, 0.3159],
        [0.3291, 0.4706, 0.2002],
        [0.2474, 0.2313, 0.5213],
        [0.3014, 0.3304, 0.3683],
        [0.3711, 0.2647, 0.3642],
        [0.2695, 0.4541, 0.2764],
        [0.4457, 0.2920, 0.2623]], dtype=torch.float64)


In [9]:
# Ordering matters: the per-class thresholds are the same as in the previous
# cell, but class 2 is now listed before class 1, so class 2 wins whenever
# both exceed their thresholds and the predicted labels change.
thresholds = "0:0.42,2:0.25,1:0.33"
thres = {
    int(parts[0]): float(parts[1])
    for parts in (item.split(':') for item in thresholds.split(','))
}

labels = predict(None, None, thres=thres, use_cache=True, params=(y_pred, True))
print(labels)
print(y_pred)

[2. 2. 2. 2. 1. 2. 2. 2. 2. 0.]
tensor([[0.3090, 0.3649, 0.3261],
        [0.3342, 0.2961, 0.3697],
        [0.2343, 0.3691, 0.3966],
        [0.2732, 0.4109, 0.3159],
        [0.3291, 0.4706, 0.2002],
        [0.2474, 0.2313, 0.5213],
        [0.3014, 0.3304, 0.3683],
        [0.3711, 0.2647, 0.3642],
        [0.2695, 0.4541, 0.2764],
        [0.4457, 0.2920, 0.2623]], dtype=torch.float64)


In [24]:
# Parse a "cls:threshold,..." spec into {class_index: threshold}. Dicts keep
# insertion order, so the order written in the spec (which predict() uses as
# priority order) survives parsing. The spec string gets its own name instead
# of being overwritten by the parsed dict.
threshold_spec = "0:.5,1:0.33,2:0.17"
thresholds = {
    int(cls): float(thr)
    for cls, thr in (item.split(':') for item in threshold_spec.split(','))
}
print(thresholds)

{0: 0.5, 1: 0.33, 2: 0.17}


In [26]:
# Same parse with class 2 listed before class 1: insertion order is preserved,
# so the parsed dict keeps the 0 -> 2 -> 1 priority order from the spec.
threshold_spec = "0:.5,2:0.33,1:0.17"
thresholds = {
    int(cls): float(thr)
    for cls, thr in (item.split(':') for item in threshold_spec.split(','))
}
print(thresholds)

{0: 0.5, 2: 0.33, 1: 0.17}
