In [31]:
import pandas as pd
import clustbench
from skquery.pairwise import FFQS
from skquery.oracle import MLCLOracle
from active_semi_clustering.semi_supervised.pairwise_constraints import COPKMeans
from sklearn.decomposition import PCA
from sklearn.metrics import adjusted_rand_score
import plotly.express as px
import plotly.graph_objects as go

In [32]:
#Plots data partition, optionnally with ML/CL constraints
def plot(dataset, partition, constraints=None, filename=None):
    viz_dataset = pd.DataFrame(PCA(n_components=2).fit_transform(dataset)) if dataset.shape[1] > 3 else pd.DataFrame(dataset)
    fig = None
    match viz_dataset.shape[1]:
        case 2:
            fig = px.scatter(viz_dataset, x=0, y=1, template="simple_white",
                             color=partition, symbol=partition,
                             hover_data={'index': viz_dataset.index.astype(str)})
        case 3:
            fig = px.scatter_3d(viz_dataset, x=0, y=1, z=2, template="simple_white",
                                color=partition, symbol=partition,
                                hover_data={'index': viz_dataset.index.astype(str)})

    if constraints:
        for key in constraints:
            for cst in constraints[key]:
                points = viz_dataset.iloc[list(cst)]
                match viz_dataset.shape[1]:
                    case 2:
                        fig.add_trace(go.Scatter(name=str(cst), x=[points.iloc[0, 0], points.iloc[1, 0]],
                                                 mode="lines", y=[points.iloc[0, 1], points.iloc[1, 1]]))
                    case 3:
                        fig.add_trace(go.Scatter3d(name=str(cst), x=[points.iloc[0, 0], points.iloc[1, 0]],
                                                   mode="lines", y=[points.iloc[0, 1], points.iloc[1, 1]],
                                                   z=[points.iloc[0, 2], points.iloc[1, 2]]))
                if key == "ml":
                    fig['data'][-1]['line']['color'] = "#ff0000"
                else:
                    fig['data'][-1]['line']['color'] = "#0000ff"
                    fig['data'][-1]['line']['dash'] = "dash"

    fig.update_layout(showlegend=False)
    fig.update(layout_coloraxis_showscale=False)
    if not filename:
        fig.show()
    else:
        fig.write_html(filename)

In [33]:
# Choose a dataset from https://clustering-benchmarks.gagolewski.com/weave/data-v1.html
dataset = clustbench.load_dataset("fcps", "lsun", path="clustering-data-v1")
labels = dataset.labels[0] - 1 # correspondance between clustbench and Python indexing

In [34]:
algo = COPKMeans(n_clusters=dataset.n_clusters[0])
algo.fit(dataset.data)
init_partition = algo.labels_
print(adjusted_rand_score(labels, algo.labels_))  # ARI entre la partition initiale et la vérité terrain : qualité
plot(dataset.data, algo.labels_)

0.4215612368745182


In [35]:
#Test run
qs = FFQS()

constraints = qs.fit(dataset.data, MLCLOracle(truth=labels, budget=20), partition=algo.labels_)
plot(dataset.data, algo.labels_, constraints)
print("ML : ", constraints["ml"])
print("CL : ", constraints["cl"])
algo.fit(dataset.data, ml=constraints["ml"], cl=constraints["cl"])
plot(dataset.data, algo.labels_)

ML :  [(272, 289), (272, 219), (272, 208), (272, 249), (272, 274), (289, 272), (289, 219), (289, 208), (289, 249), (289, 274), (219, 272), (219, 289), (219, 208), (219, 249), (219, 274), (208, 272), (208, 289), (208, 219), (208, 249), (208, 274), (249, 272), (249, 289), (249, 219), (249, 208), (249, 274), (274, 272), (274, 289), (274, 219), (274, 208), (274, 249), (95, 66), (95, 139), (95, 153), (95, 36), (95, 165), (95, 120), (95, 79), (66, 95), (66, 139), (66, 153), (66, 36), (66, 165), (66, 120), (66, 79), (139, 95), (139, 66), (139, 153), (139, 36), (139, 165), (139, 120), (139, 79), (153, 95), (153, 66), (153, 139), (153, 36), (153, 165), (153, 120), (153, 79), (36, 95), (36, 66), (36, 139), (36, 153), (36, 165), (36, 120), (36, 79), (165, 95), (165, 66), (165, 139), (165, 153), (165, 36), (165, 120), (165, 79), (120, 95), (120, 66), (120, 139), (120, 153), (120, 36), (120, 165), (120, 79), (79, 95), (79, 66), (79, 139), (79, 153), (79, 36), (79, 165), (79, 120), (318, 378), (318,