In [57]:
import numpy as np

from skactiveml.base import SingleAnnotatorPoolQueryStrategy, SkactivemlClassifier
from skactiveml.utils import (
    MISSING_LABEL,
    check_type,
    check_equal_missing_label
)

from sklearn.linear_model import LogisticRegression
from sklearn.datasets import make_classification

from skactiveml.classifier import SklearnClassifier

In [58]:
class Badge(SingleAnnotatorPoolQueryStrategy):
    def __init__(
            self,
            missing_label=MISSING_LABEL,
            random_state=None
    ):
        super().__init__(
            missing_label=missing_label, random_state=random_state
        )

    def query(
        self,
        X,
        y,
        clf,
        candidates=None,
        batch_size=1,
        return_utilities=False,
    ):
        # Validate input parameters
        X, y, candidates, batch_size, return_utilities = self._validate_data(
            X, y, candidates, batch_size, return_utilities, reset=True
        )

        X_cand, mapping = self._transform_candidates(candidates, X, y)

        # Validate classifier type
        check_type(clf, "clf", SkactivemlClassifier)
        check_equal_missing_label(clf.missing_label, self.missing_label_)
        
        clf = clf.fit(X, y)
        probas = clf.predict_proba(X_cand)
        print(probas)
        p_max = np.max(probas, axis=-1).reshape(-1, 1)
    
        print(p_max)
        g_x = (p_max - 1) * X_cand
        print(g_x)
        

In [65]:
X, y_true = make_classification(n_features=2, n_redundant=0, random_state=0)

X.shape[0]

100

In [60]:
y = np.full(shape=y_true.shape, fill_value=MISSING_LABEL)

In [64]:
clf = SklearnClassifier(LogisticRegression(), classes=np.unique(y_true))
print(np.unique(y_true))

[0 1]


In [62]:
qs = Badge()

In [63]:
qs.query(X[:5], y[:5], clf, candidates=None, batch_size=1)

[[0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]
 [0.5 0.5]]
[[0.5]
 [0.5]
 [0.5]
 [0.5]
 [0.5]]
[[ 0.49329254  1.38629638  0.14346   ]
 [ 0.57427888  0.19547669 -0.43291273]
 [-1.67986627 -0.07097658  0.89753172]
 [-0.25640728 -0.26053244 -0.21649727]
 [ 0.46019163 -0.04870008  0.0361584 ]]


