Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature/pro-1286-fit_predict_update-vs #469

Merged
merged 15 commits into from Feb 27, 2022
152 changes: 146 additions & 6 deletions relevanceai/clusterer/clusterer.py
Expand Up @@ -846,6 +846,152 @@ def delete_centroids(self, dataset: Union[str, Dataset], vector_fields: List):
)
return response.json()["status"]

def fit_predict(
    self,
    data: Union[str, Dataset, List[Dict]],
    vector_fields: List[str],
    filters: Optional[List[Dict]] = None,
    return_only_clusters: bool = True,
    include_grade: bool = False,
    update: bool = True,
    inplace: bool = False,
):
    """
    Fit the clustering model on the vectors found in ``data`` and attach
    the predicted cluster labels to the documents.

    Parameters
    ----------
    data: Union[str, Dataset, List[Dict]]
        Either a reference to a Relevance AI Dataset, be it its name
        (string) or the object itself (Dataset), or a list of documents
        (List[Dict]).

    vector_fields: List[str]
        The vector fields over which to fit the model.

    filters: List[Dict]
        A list of filters to enable for document retrieval. This only
        applies to a reference to a Relevance AI Dataset. The list is
        copied internally and never mutated.

    return_only_clusters: bool
        An indicator that determines what is returned. If True, this
        function returns the clusters. Else, the function returns the
        original documents.

    include_grade: bool
        An indicator that determines whether to include (True) a grade
        based on the mean silhouette score or not (False).

    update: bool
        An indicator that determines whether to update the documents
        that were part of the clustering process. This only applies to a
        reference to a Relevance AI Dataset.

    inplace: bool
        An indicator that determines whether the documents are edited
        inplace (True) or a copy is created and edited (False).

    Returns
    -------
    The clustered documents — only their cluster fields when
    ``return_only_clusters`` is True, otherwise the full documents.

    Example
    -------

    .. code-block::

        from relevanceai import ClusterBase, Client

        client = Client()

        import random
        class CustomClusterModel(ClusterBase):
            def fit_predict(self, X):
                cluster_labels = [random.randint(0, 100) for _ in range(len(X))]
                return cluster_labels

        model = CustomClusterModel()

        df = client.Dataset("sample_dataset")
        clusterer = client.ClusterOps(alias="random_clustering", model=model)
        clusterer.fit_predict(df, vector_fields=["sample_vector_"])

    """
    # Copy so the exists-filters appended below never mutate a
    # caller-supplied list.
    filters = [] if filters is None else list(filters)

    if update and isinstance(data, list):
        warnings.warn(
            "Cannot update list of datasets that are untethered "
            "to a Relevance AI dataset. "
            "Setting update to False."
        )
        # If data is of type List[Dict] the value of update doesn't
        # actually matter. This is more for good practice.
        update = False

    if isinstance(data, list):
        documents = data
    else:
        self._init_dataset(data)
        self.vector_fields = vector_fields
        # Make sure to only get documents where the vector fields exist.
        filters.extend(
            [
                {
                    "field": f,
                    "filter_type": "exists",
                    "condition": "==",
                    "condition_value": " ",
                    "strict": "must_or",
                }
                for f in vector_fields
            ]
        )
        # Load the documents.
        self.logger.warning(
            "Retrieving documents... This can take a while if the dataset is large."
        )
        print("Retrieving all documents")
        documents = self._get_all_documents(
            dataset_id=self.dataset_id, filters=filters, select_fields=vector_fields
        )

    # Applies to both branches: nothing to cluster without documents.
    if len(documents) == 0:
        raise NoDocumentsError()

    vectors = self._get_vectors_from_documents(vector_fields, documents)

    # Label the clusters
    print("Fitting and predicting on all relevant documents")
    cluster_labels = self._label_clusters(self.model.fit_predict(vectors))

    if include_grade:
        # Grading is best-effort: a failure to compute the silhouette
        # score must never abort the clustering run.
        try:
            self._calculate_silhouette_grade(vectors, cluster_labels)
        except Exception as e:
            print(e)

    clustered_documents = self.set_cluster_labels_across_documents(
        cluster_labels,
        documents,
        inplace=inplace,
        return_only_clusters=return_only_clusters,
    )

    if not isinstance(data, list):
        if update:
            # Updating the db
            print("Updating the database...")
            results = self._update_documents(
                self.dataset_id, clustered_documents, chunksize=10000
            )
            self.logger.info(results)

            # Update the centroid collection
            self.model.vector_fields = vector_fields

            self._insert_centroid_documents()
            print(
                "Build your clustering app here: "
                + f"https://cloud.relevance.ai/dataset/{self.dataset_id}/deploy/recent/cluster"
            )

    return clustered_documents

@track
def fit_predict_update(
self,
Expand Down Expand Up @@ -945,12 +1091,6 @@ def fit_predict(self, X):
+ f"https://cloud.relevance.ai/dataset/{self.dataset_id}/deploy/recent/cluster"
)

@track
def fit_predict(self, X):
    """Placeholder dispatcher.

    Intended to run fit-predict either on a whole dataset or on a set
    of documents; not yet implemented, so it does nothing.
    """
    return None

@track
def fit_dataset(
self,
Expand Down
20 changes: 20 additions & 0 deletions tests/unit/test_clusterer_api/test_clusterer.py
Expand Up @@ -7,6 +7,7 @@
from relevanceai.clusterer import kmeans_clusterer
from relevanceai.http_client import Dataset, Client, ClusterOps
from relevanceai.dataset_api.cluster_groupby import ClusterGroupby
from relevanceai.vector_tools.cluster import ClusterBase

CLUSTER_ALIAS = "kmeans_10"
VECTOR_FIELDS = ["sample_1_vector_"]
Expand Down Expand Up @@ -68,3 +69,22 @@ def test_agg_std(test_clusterer: ClusterOps):
groupby_agg = cluster_groupby.agg({"sample_2_value": "std_deviation"})
assert isinstance(groupby_agg, dict)
assert len(groupby_agg) > 0


def test_fit_predict(test_client: Client, vector_dataset_id: str):
    """Check that fit_predict writes the cluster field into the dataset schema."""
    from random import randint

    class RandomLabelModel(ClusterBase):
        def fit_predict(self, X):
            # One pseudo-random label in [0, 100] per input vector.
            return [randint(0, 100) for _ in range(len(X))]

    dataset = test_client.Dataset(vector_dataset_id)
    clusterer = test_client.ClusterOps(
        alias="random_clustering",
        model=RandomLabelModel(),
    )
    clusterer.fit_predict(dataset, vector_fields=["sample_1_vector_"])
    assert "_cluster_.sample_1_vector_.random_clustering" in dataset.schema