In [None]:
import torch
import hdbscan
from datasets import load_dataset
import numpy as np
from sentence_transformers import SentenceTransformer
import pandas as pd
import collections


In [None]:

# Load the AG News dataset and keep the first 100 rows of its training split.
dataset = load_dataset("ag_news")
train_dataset = dataset['train'].select(range(100))

# Encode the selected texts with the sentence-transformer model.
st = SentenceTransformer('all-roberta-large-v1')
embeddings = st.encode(train_dataset['text'])

# Cluster the embeddings of the whole 100-row subset in one pass.
clusterer = hdbscan.HDBSCAN(
    min_cluster_size=2,
    leaf_size=10,
)
clusters = clusterer.fit_predict(embeddings)

In [None]:

# Synthetic per-sample score matrix used to exercise the selection function below.
# Four length-100 curves are stacked as rows (shape (4, 100)) and then reshaped
# row-major to (100, 4). reshape does not transpose: each resulting row holds
# four consecutive values taken from a single curve.
probs = np.vstack([
    np.arange(100) * np.arange(100, 0, -1) / 1000,
    np.arange(100) * np.arange(100, 0, -1) / 2000,
    np.arange(100) * np.arange(100, 0, -1) / 2000,
    # Left-to-right evaluation order is kept so the floats match exactly.
    np.arange(100) * 2.7 * np.arange(100, 0, -1) / 2000,
]).reshape(100, 4)
# `label` column of the first 100 training rows (the same row range selected for `train_dataset`).
labels = dataset['train'].select(range(100))['label']




In [None]:

    
def _get_intraclass_clustering_data(text_list,probabilities,true_labels,embeddings,n,clusterer='hdbscan',leaf_size=10,min_cluster_size=10):
    df_probs = pd.DataFrame(probabilities)
    label_results = df_probs.apply(lambda row: row.idxmax(), axis=1).to_list()
    prob_results = df_probs.apply(lambda row: row.max(), axis=1).to_list()

    unique_labels = list(set(label_results))
    unique_labels.sort()
    print(unique_labels)
    all_labels_selected_data = []
    for label in unique_labels:
        this_label_selected_data = []
        this_label_indexes = [i for i in range(len(label_results)) if label_results[i] == label]
        # print(this_label_indexes)
        this_label_text_list =  [text_list[i] for i in this_label_indexes]
        this_label_embeddings =  [embeddings[i] for i in this_label_indexes]
        this_label_probs =  [prob_results[i] for i in this_label_indexes]
        this_label_true_labels = [true_labels[i] for i in this_label_indexes]
        this_label_label_results = [label_results[i] for i in this_label_indexes]


        print("Clustering class {}.".format(label))
        # logger.info("Clustering class {}.")

        this_label_clusters = _clusterer_fit_predict(clusterer, this_label_embeddings, leaf_size, min_cluster_size) 
        # print(len(this_label_clusters),len(this_label_indexes),len(this_label_text_list),len(this_label_embeddings),len(this_label_probs))
        unique_clusters = list(set(this_label_clusters))
        unique_clusters.sort()
        # print(unique_clusters)
        all_clusters_sorted_lists = []

        # organize by sorting sorting and zipping lists, 1 list for each cluster found
        for cluster in unique_clusters:
            this_cluster_indexes = [i for i in range(len(this_label_clusters)) if this_label_clusters[i] == cluster]
            this_cluster_probs =  [this_label_probs[i] for i in this_cluster_indexes]
            this_cluster_texts = [this_label_text_list[i] for i in this_cluster_indexes]
            this_cluster_true_labels = [this_label_true_labels[i] for i in this_cluster_indexes]
            this_cluster_label_results = [this_label_label_results[i] for i in this_cluster_indexes]
            zipped_lists = (list(zip(this_cluster_probs,this_cluster_indexes,this_cluster_true_labels,this_cluster_label_results,this_cluster_texts)))
            zipped_lists.sort(reverse=True)
            # print(zipped_lists)

            all_clusters_sorted_lists.append(zipped_lists)
        # selects data iteratively, 1 from each cluster from biggest to smallest cluster, 
        # following highest probability order inside each cluster
        while len(this_label_selected_data) < n:
            for sorted_list in all_clusters_sorted_lists:
                # print((all_clusters_sorted_lists))
                if len(sorted_list) > 0:
                    # print(sorted_list)
                    selected_element = sorted_list[0]
                    print(label,selected_element)
                    this_label_selected_data.append(selected_element)
                    sorted_list.pop(0)
                    # print(sorted_list)
                    if len(this_label_selected_data) == n:
                        break

                # print(all_clusters_sorted_lists)
            if len(all_clusters_sorted_lists) == 0 or all_clusters_sorted_lists == [[]]:
                # print(label)
                print("Not enough data to sample for label {label}: {n} samples expected, but only got {this_label_n}".format(label=label,n=n,this_label_n=len(this_label_selected_data)))
                # logger.info("Not enough data to sample for label {label}: {n} samples expected, but only got {this_label_n}".format(label=label,n=n,this_label_n=len(this_label_selected_data)))
                break
        all_labels_selected_data.append(this_label_selected_data)

    flat_selected_data = [item for sublist in all_labels_selected_data for item in sublist]

    probs,train_indices,true_labels,train_labels,texts = zip(*flat_selected_data)


    # x_train = [text_list[i] for i in train_indices]
    # y_train = [true_labels[i] for i in train_indices]
    # labels_train = [label_results[i] for i in train_indices]
    # print(x_train,y_train,labels_train)

    x_train = texts
    y_train = true_labels
    labels_train = train_labels 
    print(x_train,y_train,labels_train)
                                 
    return x_train, y_train, labels_train

def _clusterer_fit_predict(clusterer,embeddings,leaf_size,min_cluster_size):
    if clusterer=='hdbscan':
        clusterer = hdbscan.HDBSCAN(leaf_size=leaf_size, min_cluster_size=min_cluster_size)
    # print(len(embeddings))
    clusters = clusterer.fit_predict(embeddings)
    # logger.info("Found {} clusters.".format(len(list(set(clusters)))))
    print("Found {} clusters.".format(len(list(set(clusters)))))
    return clusters



# Draw up to 8 samples per predicted class from the 100-row subset.
x_train, y_train, labels_train = _get_intraclass_clustering_data(
    train_dataset['text'],
    probs,
    labels,
    embeddings,
    8,
    min_cluster_size=2,
)