In [1]:
import numpy as np
from scipy import stats
from sklearn import clone
from sklearn.metrics import pairwise_distances_argmin_min

from skactiveml.base import SingleAnnotatorPoolQueryStrategy, SkactivemlClassifier
from skactiveml.utils import (
    MISSING_LABEL,
    check_type,
    check_equal_missing_label,
    unlabeled_indices
)

from sklearn.datasets import make_classification
from skactiveml.classifier import SklearnClassifier
from sklearn.linear_model import LogisticRegression

In [2]:
class Badge(SingleAnnotatorPoolQueryStrategy):
    def __init__(
            self,
            missing_label=MISSING_LABEL,
            random_state=None
    ):
        super().__init__(
            missing_label=missing_label, random_state=random_state
        )

    def query(
            self,
            X,
            y,
            clf,
            candidates=None,
            batch_size=1,
            return_utilities=False,
            return_embeddings=False,
    ):
        # Validate input parameters
        X, y, candidates, batch_size, return_utilities = self._validate_data(
            X, y, candidates, batch_size, return_utilities, reset=True
        )

        X_cand, mapping = self._transform_candidates(candidates, X, y)

        # Validate classifier type
        check_type(clf, "clf", SkactivemlClassifier)
        check_equal_missing_label(clf.missing_label, self.missing_label_)
        if not isinstance(return_embeddings, bool):
            raise TypeError("'return_embeddings' must be a boolean.")

        # Fit the classifier
        clf = clone(clf).fit(X, y)

        # find the unlabeled dataset
        if candidates is None:
            X_unlbld = X_cand
            unlbld_mapping = mapping
        elif mapping is not None:
            unlbld_mapping = unlabeled_indices(y[mapping], missing_label=self.missing_label)
            X_unlbld = X_cand[unlbld_mapping]
            unlbld_mapping = mapping[unlbld_mapping]
        else:
            X_unlbld = X_cand
            unlbld_mapping = np.arange(len(X_cand))

        # Gradient embedding, aka predict class membership probabilities
        
        probas = clf.predict_proba(X_unlbld)
        print(probas)
        p_max = np.max(probas, axis=1).reshape(-1,1) #gaile
        g_x = (p_max - 1) * X_unlbld

        # init the utilities
        if mapping is not None:
            utilities = np.full(shape=(batch_size, X.shape[0]), fill_value=np.nan)
        else:
            utilities = np.full(shape=(batch_size, X_cand.shape[0]), fill_value=np.nan)

        # 2. sampling with kmeans++
        query_indicies = np.array([], dtype=int)
        query_indicies_in_unlbld = np.array([], dtype=int)
        d_2_s = []
        for i in range(batch_size):
            if i == 0:
                d_2 = _d_2(g_x, [])
            else:
                d_2 = _d_2(g_x, [idx_in_unlbld], d_2_s[i-1])
            d_2_s.append(d_2)

            d_2_sum = np.sum(d_2)
            if i == 0 or d_2_sum == 0:
                d_2 = np.ones(shape=len(g_x))
                d_2[query_indicies_in_unlbld] = 0
                d_2_sum = np.sum(d_2)
            
            d_probas = d_2 / d_2_sum
            
            utilities[i, unlbld_mapping] = d_probas
            utilities[i, query_indicies] = np.nan

            idx_in_unlbld_array = self.random_state_.choice(len(d_probas), 1, replace=False, p=d_probas)
            
            idx_in_unlbld = idx_in_unlbld_array[0]
            query_indicies_in_unlbld = np.append(query_indicies_in_unlbld, idx_in_unlbld_array)
            
            idx = unlbld_mapping[idx_in_unlbld]
            query_indicies = np.append(query_indicies, [idx])
        
        if return_utilities:
            return query_indicies, utilities
        else:
            return query_indicies


def _d_2(g_x, query_indicies, d_latest=None):
    if len(query_indicies) == 0:
        return np.full(shape=len(g_x), fill_value=np.inf)
    g_query_indicies = g_x[query_indicies]
    _, D = pairwise_distances_argmin_min(X=g_x, Y=g_query_indicies)
    if d_latest is not None:
        D = np.minimum(d_latest, D)
    D2 = np.square(D)
    D2_sum = np.sum(D2)
    if D2_sum == 0:
        return np.full(shape=len(g_x), fill_value=np.inf)
    return D2

In [3]:
X, y_true = make_classification(n_features=2, n_redundant=0, random_state=0)
y = np.full(shape=y_true.shape, fill_value=MISSING_LABEL)
clf = SklearnClassifier(LogisticRegression(), classes=np.unique(y_true))
qs = Badge(random_state=0)

In [4]:
query = qs.query(X[:5], y[:5], clf, candidates=None, batch_size=3)
print(query)

[[0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]]
[4 2 3]




In [5]:
query, utilities = qs.query(X[:5], y[:5], clf, candidates=None, batch_size=3, return_utilities=True)
print(query)
print(utilities)

[[0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]]
[4 2 3]
[[0.2        0.2        0.2        0.2        0.2       ]
 [0.23279271 0.27312903 0.30139662 0.19268163        nan]
 [0.03267711 0.00610958        nan 0.96121331        nan]]




In [6]:
y[:50] = y_true[:50]
query, utilities = qs.query(X, y, clf, candidates=None, batch_size=3, return_utilities=True)
print(query)
print(utilities)

[[0.62511502 0.37488498]
 [0.37775923 0.62224077]
 [0.0930193  0.9069807 ]
 [0.77232873 0.22767127]
 [0.62482864 0.37517136]
 [0.62863896 0.37136104]
 [0.79984337 0.20015663]
 [0.07336398 0.92663602]
 [0.48481203 0.51518797]
 [0.2684722  0.7315278 ]
 [0.25460785 0.74539215]
 [0.78241932 0.21758068]
 [0.79685072 0.20314928]
 [0.01462891 0.98537109]
 [0.17932071 0.82067929]
 [0.48403444 0.51596556]
 [0.87616605 0.12383395]
 [0.24597539 0.75402461]
 [0.71322158 0.28677842]
 [0.79553289 0.20446711]
 [0.87887735 0.12112265]
 [0.06857487 0.93142513]
 [0.00153022 0.99846978]
 [0.7635098  0.2364902 ]
 [0.0035249  0.9964751 ]
 [0.41262771 0.58737229]
 [0.28806432 0.71193568]
 [0.0257005  0.9742995 ]
 [0.42350324 0.57649676]
 [0.78020576 0.21979424]
 [0.01649936 0.98350064]
 [0.67626141 0.32373859]
 [0.0800561  0.9199439 ]
 [0.74563633 0.25436367]
 [0.94338196 0.05661804]
 [0.10634404 0.89365596]
 [0.04405831 0.95594169]
 [0.55234121 0.44765879]
 [0.43123367 0.56876633]
 [0.40180296 0.59819704]


In [7]:
query, utilities = qs.query(X, y, clf, candidates=X[50:60], batch_size=3, return_utilities=True)
print(query)
print(utilities)

[[0.62511502 0.37488498]
 [0.37775923 0.62224077]
 [0.0930193  0.9069807 ]
 [0.77232873 0.22767127]
 [0.62482864 0.37517136]
 [0.62863896 0.37136104]
 [0.79984337 0.20015663]
 [0.07336398 0.92663602]
 [0.48481203 0.51518797]
 [0.2684722  0.7315278 ]]
[2 8 4]
[[1.00000000e-01 1.00000000e-01 1.00000000e-01 1.00000000e-01
  1.00000000e-01 1.00000000e-01 1.00000000e-01 1.00000000e-01
  1.00000000e-01 1.00000000e-01]
 [1.25682550e-01 2.68979685e-02            nan 4.79296172e-02
  1.78660497e-01 9.54018139e-02 5.48850504e-02 1.12172367e-03
  4.27446035e-01 4.19747446e-02]
 [2.62050892e-01 1.20025807e-02            nan 3.81104076e-02
  4.57622164e-01 1.50990262e-01 4.99739650e-02 2.08740846e-05
             nan 2.92288546e-02]]


In [8]:
np.sum([np.inf, np.inf])

inf

In [9]:
np.random.choice(5, 1, replace=False, p=[0.1, 0, 0.3, 0.6, 0])

array([2])