In [217]:
import numpy as np
import metrics
from sklearn.linear_model import RidgeClassifierCV

ridge = RidgeClassifierCV()

In [218]:
X = np.load('x_se_resnext50_32x4d_0.npy')
y = np.load('y_se_resnext50_32x4d_0.npy')

In [219]:
from sklearn.cross_validation import train_test_split

def sigmoid(x):
    return 1/(1 + np.exp(-x))

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.25)
X_train = sigmoid(X_train)
X_test = sigmoid(X_test)

In [220]:
def generate_preds(pred, threshold):
    integer = np.arange(28)[pred > threshold]
    a = np.zeros(28)
    [np.put(a, i, 1) for i in(integer)]
    
    return a

def threshold_error_mapping(threshold):
    final_pred = np.array([generate_preds(X_train[i], threshold) for i in range(len(y_train))])
    return -metrics.f1_macro(final_pred, y_train)

In [221]:
from skopt import gp_minimize

res = gp_minimize(threshold_error_mapping,                  # the function to minimize
                  [(0.4, .6)]*28,      # the bounds on each dimension of x
                  acq_func="EI",      # the acquisition function
                  n_calls=100,         # the number of evaluations of f 
                  n_random_starts=5,  # the number of random initialization points
                  noise=0.1**2,       # the noise level (optional)
                  random_state=123)   # the random seed

In [222]:
print("Value of f1 macro score: {}".format(res.fun))
print("Thresholds: {}".format(res.x))

Value of f1 macro score: -0.7017762560866011
Thresholds: [0.4, 0.6, 0.4, 0.6, 0.4, 0.4, 0.5951117108331118, 0.4853869207383187, 0.6, 0.5687016706298437, 0.4573993807698044, 0.4, 0.4, 0.4653931447260238, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.58698943976186, 0.6, 0.556517758252218, 0.6, 0.5950956144670102, 0.41854161523071076, 0.4, 0.4]


array([[3.19453228e-03, 1.70492440e-04, 3.79152266e-04, ...,
        1.50876690e-04, 8.26269841e-08, 8.68114516e-08],
       [9.85534891e-01, 1.22810064e-02, 1.45901180e-03, ...,
        1.28603279e-04, 5.89215142e-06, 5.63633814e-06],
       [9.70043808e-02, 9.99874830e-01, 2.11967796e-04, ...,
        9.84548115e-03, 1.09272171e-05, 8.02535389e-07],
       ...,
       [4.50812949e-05, 1.56078614e-06, 3.88409983e-05, ...,
        5.69586953e-04, 3.06504216e-04, 1.46943207e-08],
       [7.35051167e-01, 4.76892167e-07, 1.41028430e-06, ...,
        3.30726390e-04, 3.41177657e-06, 1.63957556e-09],
       [2.57525461e-04, 7.08599209e-07, 2.47160458e-05, ...,
        1.81226965e-05, 3.47075806e-08, 6.99828128e-09]])