# Load Data

In [1]:
from sklearn import datasets
# Load the Iris dataset that ships with scikit-learn.
iris = datasets.load_iris()

# Display the names of the fields available on the loaded dataset object.
dataset_fields = list(iris.keys())
dataset_fields

['data',
 'target',
 'frame',
 'target_names',
 'DESCR',
 'feature_names',
 'filename',
 'data_module']

In [2]:
# Peek at the first row of the feature matrix to see what one sample looks like.
first_sample = iris["data"][0]
first_sample

array([5.1, 3.5, 1.4, 0.2])

In [3]:
import numpy as np

# Feature matrix: one row per sample.
X = iris["data"]

# Binary target: 1 if Iris virginica (class id 2), else 0.
# The builtin `int` is used because the `np.int` alias was deprecated in
# NumPy 1.20 and removed in 1.24, where it raises AttributeError.
y = (iris["target"] == 2).astype(int)

In [4]:
# y_true stays untouched as the ground truth used for scoring later on;
# y_experiment is an independent working copy whose labels can be hidden
# without affecting y_true or y.
y_true, y_experiment = y.copy(), y.copy()

In [5]:
# Fixed seed so that the same samples are hidden on every run.
rng = np.random.RandomState(42)

# One random draw per sample; samples whose draw falls below 0.3 lose their label.
uniform_draws = rng.rand(len(y_experiment))
random_unlabeled_points = uniform_draws < 0.3

# -1 is the marker for "label unknown".
y_experiment[random_unlabeled_points] = -1

In [6]:
# Dimensions of the feature matrix.
np.shape(X)

(150, 4)

# RFoT: predict the hidden labels

In [13]:
from RFoT import RFoT

# Hyperparameters gathered in a single mapping so they can be reviewed and
# tuned in one place before the model is constructed.
rfot_params = dict(
    bin_scale=1,
    max_dimensions=3,
    component_purity_tol=1.0,
    min_rank=2,
    max_rank=3,
    n_estimators=50,
    bin_entry=True,
    clustering="ms",
    max_depth=2,
    n_jobs=50,
)
model = RFoT(**rfot_params)

# y_experiment holds -1 for every sample whose label was hidden; the returned
# array is later compared against -1 to find samples left without a prediction.
y_pred = model.predict(X, y_experiment)

100%|██████████| 22/22 [00:00<00:00, 63.61it/s]
100%|██████████| 22/22 [00:00<00:00, 225.14it/s]


# Look at the results

In [14]:
from sklearn.metrics import f1_score

# Positions of the samples whose label was hidden (-1) before prediction.
unknown_indices = np.flatnonzero(y_experiment == -1)

# Offsets *into unknown_indices* where the model returned a label rather than
# abstaining (-1).
did_predict_indices = np.flatnonzero(y_pred[unknown_indices] != -1)

# Absolute positions of the hidden samples that received a prediction.
evaluated_indices = unknown_indices[did_predict_indices]

# Count abstentions among the hidden samples only, so that the numerator
# matches the denominator of the percentage printed below (previously the
# count was taken over every entry of y_pred).
abstaining_count = len(unknown_indices) - len(did_predict_indices)

# Score only the hidden samples for which a prediction was made.
f1 = f1_score(
    y_true[evaluated_indices],
    y_pred[evaluated_indices],
    average="weighted",
)

print("------------------------")
print("Num. of Abstaining", abstaining_count)
print("Percent Abstaining", (abstaining_count / len(unknown_indices)) * 100, "%")
print("F1=", f1)

------------------------
Num. of Abstaining 17
Percent Abstaining 33.33333333333333 %
F1= 1.0


In [15]:
from sklearn.metrics import classification_report

# did_predict_indices are offsets into unknown_indices, so combining the two
# gives the absolute positions of hidden samples that received a prediction.
scored_positions = unknown_indices[did_predict_indices]
y_true_hat = y_true[scored_positions]
y_pred_hat = y_pred[scored_positions]

# Per-class precision / recall / F1 on the scored samples.
report_text = classification_report(y_true_hat, y_pred_hat)
print(report_text)

              precision    recall  f1-score   support

           0       1.00      1.00      1.00        19
           1       1.00      1.00      1.00        15

    accuracy                           1.00        34
   macro avg       1.00      1.00      1.00        34
weighted avg       1.00      1.00      1.00        34

