Skip to content

Commit

Permalink
Merge pull request #469 from RelevanceAI/feature/pro-1286-fit_predict…
Browse files Browse the repository at this point in the history
…_update-vs

feature/pro-1286-fit_predict_update-vs
  • Loading branch information
boba-and-beer committed Feb 27, 2022
2 parents 473e2f3 + 5f4e10c commit d0d7cf0
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 6 deletions.
152 changes: 146 additions & 6 deletions relevanceai/clusterer/clusterer.py
Expand Up @@ -846,6 +846,152 @@ def delete_centroids(self, dataset: Union[str, Dataset], vector_fields: List):
)
return response.json()["status"]

def fit_predict(
    self,
    data: Union[str, Dataset, List[Dict]],
    vector_fields: List[str],
    filters: Optional[List[Dict]] = None,
    return_only_clusters: bool = True,
    include_grade: bool = False,
    update: bool = True,
    inplace: bool = False,
):
    """
    Fit the clustering model over the given vectors and label each document.

    Parameters
    ----------
    data: Union[str, Dataset, List[Dict]]
        Either a reference to a Relevance AI Dataset, be it its name
        (string) or the object itself (Dataset), or a list of documents
        (List[Dict]).
    vector_fields: List[str]
        The vector fields over which to fit the model.
    filters: List[Dict]
        A list of filters to enable for document retrieval. This only
        applies to a reference to a Relevance AI Dataset.
    return_only_clusters: bool
        An indicator that determines what is returned. If True, this
        function returns the clusters. Else, the function returns the
        original documents.
    include_grade: bool
        An indicator that determines whether to include (True) a grade
        based on the mean silhouette score or not (False).
    update: bool
        An indicator that determines whether to update the documents
        that were part of the clustering process. This only applies to a
        reference to a Relevance AI Dataset.
    inplace: bool
        An indicator that determines whether the documents are edited
        inplace (True) or a copy is created and edited (False).

    Example
    -------
    .. code-block::

        from relevanceai import ClusterBase, Client
        client = Client()
        import random

        class CustomClusterModel(ClusterBase):
            def fit_predict(self, X):
                cluster_labels = [random.randint(0, 100) for _ in range(len(X))]
                return cluster_labels

        model = CustomClusterModel()
        df = client.Dataset("sample_dataset")
        clusterer = client.ClusterOps(alias="random_clustering", model=model)
        # Fixed: the example previously called ``fit_predict_update``,
        # which is a different method; this documents ``fit_predict``.
        clusterer.fit_predict(df, vector_fields=["sample_vector_"])
    """
    filters = [] if filters is None else filters

    if update and isinstance(data, list):
        warnings.warn(
            "Cannot update list of datasets that are untethered "
            "to a Relevance AI dataset. "
            "Setting update to False."
        )
        # If data is of type List[Dict] the value of update doesn't
        # actually matter. This is more for good practice.
        update = False

    if isinstance(data, list):
        documents = data
    else:
        self._init_dataset(data)
        self.vector_fields = vector_fields
        # make sure to only get fields where vector fields exist
        filters.extend(
            [
                {
                    "field": f,
                    "filter_type": "exists",
                    "condition": "==",
                    "condition_value": " ",
                    "strict": "must_or",
                }
                for f in vector_fields
            ]
        )
        # load the documents
        self.logger.warning(
            "Retrieving documents... This can take a while if the dataset is large."
        )
        print("Retrieving all documents")
        documents = self._get_all_documents(
            dataset_id=self.dataset_id, filters=filters, select_fields=vector_fields
        )
        if len(documents) == 0:
            raise NoDocumentsError()

    vectors = self._get_vectors_from_documents(vector_fields, documents)

    # Label the clusters
    print("Fitting and predicting on all relevant documents")
    cluster_labels = self._label_clusters(self.model.fit_predict(vectors))

    if include_grade:
        # Best-effort: grading must never abort the clustering run.
        # NOTE(review): the return value of _calculate_silhouette_grade is
        # discarded here — presumably it reports the grade itself; confirm.
        try:
            self._calculate_silhouette_grade(vectors, cluster_labels)
        except Exception as e:
            print(e)

    clustered_documents = self.set_cluster_labels_across_documents(
        cluster_labels,
        documents,
        inplace=inplace,
        return_only_clusters=return_only_clusters,
    )

    if not isinstance(data, list):
        if update:
            # Updating the db
            print("Updating the database...")
            results = self._update_documents(
                self.dataset_id, clustered_documents, chunksize=10000
            )
            self.logger.info(results)

            # Update the centroid collection
            self.model.vector_fields = vector_fields

            self._insert_centroid_documents()
            print(
                "Build your clustering app here: "
                + f"https://cloud.relevance.ai/dataset/{self.dataset_id}/deploy/recent/cluster"
            )

    return clustered_documents

@track
def fit_predict_update(
self,
Expand Down Expand Up @@ -945,12 +1091,6 @@ def fit_predict(self, X):
+ f"https://cloud.relevance.ai/dataset/{self.dataset_id}/deploy/recent/cluster"
)

@track
def fit_predict(self, X):
    """Fit the model and predict cluster labels for ``X``.

    Intended to dispatch on the input: run over a whole dataset when
    given a dataset reference, or over a set of documents when given
    raw documents. Not yet implemented — currently a no-op.
    """
    pass

@track
def fit_dataset(
self,
Expand Down
20 changes: 20 additions & 0 deletions tests/unit/test_clusterer_api/test_clusterer.py
Expand Up @@ -7,6 +7,7 @@
from relevanceai.clusterer import kmeans_clusterer
from relevanceai.http_client import Dataset, Client, ClusterOps
from relevanceai.dataset_api.cluster_groupby import ClusterGroupby
from relevanceai.vector_tools.cluster import ClusterBase

# Alias under which the test clustering results are stored — TODO confirm
# this matches the fixture's k-means configuration.
CLUSTER_ALIAS = "kmeans_10"
# Vector field(s) the test fixtures cluster over.
VECTOR_FIELDS = ["sample_1_vector_"]
Expand Down Expand Up @@ -68,3 +69,22 @@ def test_agg_std(test_clusterer: ClusterOps):
groupby_agg = cluster_groupby.agg({"sample_2_value": "std_deviation"})
assert isinstance(groupby_agg, dict)
assert len(groupby_agg) > 0


def test_fit_predict(test_client: Client, vector_dataset_id: str):
    """fit_predict should write cluster labels back into the dataset schema."""
    import random

    class CustomClusterModel(ClusterBase):
        # Assign an arbitrary label per vector; correctness of the labels
        # themselves is irrelevant to this test.
        def fit_predict(self, X):
            return [random.randint(0, 100) for _ in X]

    clusterer = test_client.ClusterOps(
        alias="random_clustering",
        model=CustomClusterModel(),
    )
    df = test_client.Dataset(vector_dataset_id)
    clusterer.fit_predict(df, vector_fields=["sample_1_vector_"])

    # The cluster field must now be present in the dataset's schema.
    assert "_cluster_.sample_1_vector_.random_clustering" in df.schema

0 comments on commit d0d7cf0

Please sign in to comment.