In [20]:


import numpy as np

from skactiveml.base import SingleAnnotatorPoolQueryStrategy
from skactiveml.utils import MISSING_LABEL, labeled_indices
from sklearn.datasets import make_classification
from sklearn.metrics import pairwise_distances
from sklearn.cluster import KMeans
from sklearn.base import ClusterMixin

In [21]:
class TypiClust(SingleAnnotatorPoolQueryStrategy):
    def __init__(
        self,
        missing_label=MISSING_LABEL,
        random_state=None,
    ):
        super().__init__(
            missing_label=missing_label, random_state=random_state
        )

    def query(
        self,
        X,
        y,
        clust_algo=KMeans,
        k=5,
        candidates=None,
        batch_size=1,
        return_utilities=False,
    ):
        X, y, candidates, batch_size, return_utilities = self._validate_data(
            X, y, candidates, batch_size, return_utilities, reset=True
        )

        X_cand, mapping = self._transform_candidates(candidates, X, y)

        # Validate Clustering Algorithm?
        if not issubclass(clust_algo, ClusterMixin):
            raise TypeError("Only clustering algorithm from super class sklearn.ClusterMixin is supported.")

        if not isinstance(k, int):
            raise TypeError("Only k as integer is supported.")

        selected_samples = labeled_indices(y, missing_label=self.missing_label)
        n_clusters = len(selected_samples) + batch_size

        clustering_algo = clust_algo(n_clusters=n_clusters, random_state=self.random_state)

        cluster_labels = clustering_algo.fit_predict(X)
        cluster_ids, cluster_sizes = np.unique(cluster_labels, return_counts=True)

        print(cluster_sizes)

        covered_cluster = np.unique([cluster_labels[i] for i in selected_samples])
        print(covered_cluster[0].__class__)

        #for c_idx in covered_cluster:
        #    cluster_sizes[c_idx] = 0
        cluster_sizes[covered_cluster] = 0

        print(cluster_sizes)

        utilities = np.zeros(shape=(batch_size, X.shape[0]))
        query_indices = []

        for i in range(batch_size):
            cluster_id = np.argmax(cluster_sizes)
            uncovered_samples_mapping = [idx for idx, value in enumerate(cluster_labels) if value == cluster_id]
            typicality = _typicality(X, uncovered_samples_mapping, k)
            idx = np.argmax(typicality)
            typicality[selected_samples] = np.nan
            utilities[i] = typicality

            query_indices = np.append(query_indices, [idx])
            selected_samples = np.append(selected_samples, [idx])
            cluster_sizes[cluster_ids] = 0

        if return_utilities:
            return query_indices, utilities
        else:
            return query_indices


def _typicality(X, uncovered_samples_mapping, k):
    typicality = np.zeros(shape=X.shape[0])
    dist_matrix = pairwise_distances(X[uncovered_samples_mapping])
    dist_matrix_sort_inc = np.sort(dist_matrix)
    knn = np.sum(dist_matrix_sort_inc[:, :k+1], axis=1)
    typi = 1 / (1 / k * knn)
    for idx, value in enumerate(uncovered_samples_mapping):
        typicality[value] = typi[idx]
    return typicality

In [22]:
X, y_true = make_classification(n_features=2, n_redundant=0, random_state=0)
y = np.full(shape=y_true.shape, fill_value=MISSING_LABEL)
mapping = np.arange(3)
y[mapping] = y_true[mapping]

In [23]:
qs = TypiClust()

In [24]:
query_indices, utilites = qs.query(X, y, batch_size=2, return_utilities=True)
print(query_indices)
print(utilites)

[31 30 26  9  4]
<class 'numpy.int32'>
[ 0 30  0  9  4]
[27. 53.]
[[       nan        nan        nan 2.22700498 0.         0.
  0.         0.         0.         0.         0.         0.57105701
  0.         0.         0.         3.07314926 0.         0.
  0.         0.         0.         2.29476449 1.80667844 0.
  2.49912811 0.         0.         4.20356886 0.         0.
  0.         0.         3.23772571 0.         2.53191068 3.11016292
  0.         1.15298427 2.40325448 2.8994196  0.         1.45326822
  0.         0.         0.         0.         1.03744451 0.
  0.         0.         0.         0.         2.21991985 0.
  0.         0.         0.         3.26136825 0.         1.79419345
  3.60324171 0.         0.         0.         1.73903749 0.
  0.         3.95598081 0.         0.         0.         3.12695827
  0.         0.         0.         0.         0.         2.23917218
  0.         0.         0.         0.         1.96864763 0.
  0.         2.58636555 1.03013844 0.         

