diff --git a/decanter_ai_sdk/client.py b/decanter_ai_sdk/client.py index ea68f20..47fea2d 100644 --- a/decanter_ai_sdk/client.py +++ b/decanter_ai_sdk/client.py @@ -520,6 +520,20 @@ def predict_ts( ) return prediction + def stop_uploading(self, id: str) -> None: + if self.api.stop_uploading(id): + logging.info("Uploading task: " + id + " stopped successfully.") + else: + logging.info("This task has already stopped or doesn't exist.") + return None + + def stop_training(self, id: str) -> None: + if self.api.stop_training(id): + logging.info("Experiment: " + id + " stopped successfully.") + else: + logging.info("This task has already stopped or doesn't exist.") + return None + def wait_for_response(self, url, id): pbar = tqdm(total=100, desc=url + " task is now pending") progress = 0 @@ -530,7 +544,9 @@ def wait_for_response(self, url, id): raise RuntimeError(res["progress_message"]) if res["status"] == "running": - pbar.set_description("[" + url + "] " + res["progress_message"]) + pbar.set_description( + "[" + url + "] " + "id: " + id + " " + res["progress_message"] + ) pbar.update(int(float(res["progress"]) * 100) - progress) progress = int(float(res["progress"]) * 100) diff --git a/decanter_ai_sdk/web_api/api.py b/decanter_ai_sdk/web_api/api.py index 9754a05..621c1a8 100644 --- a/decanter_ai_sdk/web_api/api.py +++ b/decanter_ai_sdk/web_api/api.py @@ -42,3 +42,11 @@ def get_table(self, data_id): @abstractmethod def get_model_list(self, experiment_id, query): raise NotImplementedError + + @abstractmethod + def stop_uploading(self, id): + raise NotImplementedError + + @abstractmethod + def stop_training(self, id): + raise NotImplementedError diff --git a/decanter_ai_sdk/web_api/decanter_api.py b/decanter_ai_sdk/web_api/decanter_api.py index 7df1da4..0275ffa 100644 --- a/decanter_ai_sdk/web_api/decanter_api.py +++ b/decanter_ai_sdk/web_api/decanter_api.py @@ -147,3 +147,21 @@ def get_model_list(self, experiment_id, query): # pragma: no cover verify=False, ) return res.json()["model_list"] + + def stop_uploading(self, id) -> bool: # pragma: no cover + res = requests.post( + f"{self.url}table/stop", + headers=self.auth_headers, + verify=False, + data={"table_id": id, "project_id": self.project_id}, + ) + return res.ok + + def stop_training(self, id) -> bool: # pragma: no cover + res = requests.post( + f"{self.url}experiment/stop", + headers=self.auth_headers, + verify=False, + data={"experiment_id": id, "project_id": self.project_id}, + ) + return res.ok diff --git a/decanter_ai_sdk/web_api/iid_testing_api.py b/decanter_ai_sdk/web_api/iid_testing_api.py index fc400dd..0e9abfa 100644 --- a/decanter_ai_sdk/web_api/iid_testing_api.py +++ b/decanter_ai_sdk/web_api/iid_testing_api.py @@ -74,3 +74,13 @@ def get_model_list(self, experiment_id, query): f = open(current_path + "/data/model_list.json") model_list_data = json.load(f) return model_list_data + + def stop_uploading(self, id): + if id == "": + return False + return True + + def stop_training(self, id): + if id == "": + return False + return True diff --git a/decanter_ai_sdk/web_api/ts_testing_api.py b/decanter_ai_sdk/web_api/ts_testing_api.py index 656be02..3d6146b 100644 --- a/decanter_ai_sdk/web_api/ts_testing_api.py +++ b/decanter_ai_sdk/web_api/ts_testing_api.py @@ -74,3 +74,13 @@ def get_model_list(self, experiment_id, query): f = open(current_path + "/data/model_list.json") model_list_data = json.load(f) return model_list_data + + def stop_uploading(self, id): + if id == "": + return False + return True + + def stop_training(self, id): + if id == "": + return False + return True diff --git a/examples/iid_example.py b/examples/iid_example.py index 3c844fe..e0e3292 100644 --- a/examples/iid_example.py +++ b/examples/iid_example.py @@ -16,11 +16,11 @@ def test_iid(): train_file_path = os.path.join(current_path, "../data/train.csv") train_file = open(train_file_path, "rb") - train_id = client.upload(train_file, "../data/test_file") + train_id = client.upload(train_file, "train_file") test_file_path = os.path.join(current_path, "../data/test.csv") test_file = open(test_file_path, "rb") - test_id = client.upload(test_file, "../data/test_file") + test_id = client.upload(test_file, "test_file") print("This will show top 2 uploaded table names and ids: \n") diff --git a/examples/ts_example.py b/examples/ts_example.py index dd678a1..13d900f 100644 --- a/examples/ts_example.py +++ b/examples/ts_example.py @@ -9,7 +9,7 @@ def test_iid(): auth_key = "" # TODO fill in real authorization key project_id = "" # TODO fill in real project id host = "" # TODO fill in real host - print("---From test iid---") + print("---From test ts---") client = Client(auth_key=auth_key, project_id=project_id, host=host) diff --git a/tests/test_mock_iid.py b/tests/test_mock_iid.py index 10c8d50..ac6934c 100644 --- a/tests/test_mock_iid.py +++ b/tests/test_mock_iid.py @@ -22,6 +22,9 @@ def test_iid(): train_file = open(train_file_path, "rb") train_id = client.upload(train_file, "train_file") + client.stop_uploading(train_id) + client.stop_uploading("") + test_file_path = os.path.join(current_path, "../data/test.csv") test_file = open(test_file_path, "rb") test_id = client.upload(test_file, "test_file") @@ -43,6 +46,9 @@ def test_iid(): }, ) + client.stop_training(experiment.id) + client.stop_training("") + best_model = experiment.get_best_model() assert ( experiment.get_best_model_by_metric( diff --git a/tests/test_mock_ts.py b/tests/test_mock_ts.py index 19ed2f3..6fb6c57 100644 --- a/tests/test_mock_ts.py +++ b/tests/test_mock_ts.py @@ -23,6 +23,9 @@ def test_ts(): train_file_df = pd.read_csv(open(train_file_path, "rb")) train_id = client.upload(train_file_df, "train_file") + client.stop_uploading(train_id) + client.stop_uploading("") + test_file_path = os.path.join(current_path, "../data/ts_test.csv") test_file = open(test_file_path, "rb") test_id = client.upload(test_file, "test_file") @@ -47,6 +50,9 @@ def test_ts(): custom_feature_types={"Pclass": DataType.numerical}, ) + client.stop_training(experiment.id) + client.stop_training("") + best_model = experiment.get_best_model() for metric in RegressionMetric: