From 6f92e9d240d37a3f1c314d0c44732b15f4b5cb13 Mon Sep 17 00:00:00 2001 From: Adilkhan Sarsen <54854336+adolkhan@users.noreply.github.com> Date: Wed, 22 Nov 2023 01:31:37 +0600 Subject: [PATCH] Deep Memory recall fix (#2698) fixing deep memory recall --------- Co-authored-by: adolkhan --- deeplake/client/test_client.py | 50 +++++++++++++++++-- deeplake/client/utils.py | 8 +-- .../deep_memory/precomputed_jobs_list.txt | 4 +- 3 files changed, 54 insertions(+), 8 deletions(-) diff --git a/deeplake/client/test_client.py b/deeplake/client/test_client.py index 67f7f0f09d..c2c3051fad 100644 --- a/deeplake/client/test_client.py +++ b/deeplake/client/test_client.py @@ -128,6 +128,19 @@ class Status: "--------------------------------------------------------------\n\n\n" ) + completed_no_improvement = ( + "--------------------------------------------------------------\n" + "| 1338464cd80cab681bfcfw23 |\n" + "--------------------------------------------------------------\n" + "| status | completed |\n" + "--------------------------------------------------------------\n" + "| progress | eta: 100.3 seconds |\n" + "| | recall@10: 100.0% (+0.0%) |\n" + "--------------------------------------------------------------\n" + "| results | recall@10: 100.0% (+0.0%) |\n" + "--------------------------------------------------------------\n\n\n" + ) + failed = ( "--------------------------------------------------------------\n" "| 1338464cd80cab681bfcfff3 |\n" @@ -168,7 +181,7 @@ def test_deepmemory_print_status_and_list_jobs(capsys, precomputed_jobs_list): progress=None, ) response_schema = JobResponseStatusSchema(response=pending_response) - response_schema.print_status(job_id, recall=None, importvement=None) + response_schema.print_status(job_id, recall=None, improvement=None) captured = capsys.readouterr() assert captured.out == Status.pending @@ -176,7 +189,7 @@ def test_deepmemory_print_status_and_list_jobs(capsys, precomputed_jobs_list): job_id = "3218464cd80cab681bfcfff3" training_response = create_response(job_id=job_id) response_schema = JobResponseStatusSchema(response=training_response) - response_schema.print_status(job_id, recall="85.5", importvement="2.6") + response_schema.print_status(job_id, recall="85.5", improvement="2.6") captured = capsys.readouterr() assert captured.out == Status.training @@ -187,10 +200,36 @@ def test_deepmemory_print_status_and_list_jobs(capsys, precomputed_jobs_list): status="completed", ) response_schema = JobResponseStatusSchema(response=completed_response) - response_schema.print_status(job_id, recall="85.5", importvement="2.6") + response_schema.print_status(job_id, recall="85.5", improvement="2.6") captured = capsys.readouterr() assert captured.out == Status.completed + job_id = "1338464cd80cab681bfcfw23" + completed_no_improvement_response = create_response( + job_id=job_id, + status="completed", + progress={ + "eta": 100.34, + "last_update_at": "2021-08-31T15:00:00.000000", + "error": None, + "train_recall@10": "87.8%", + "best_recall@10": "100.0% (+0.0)%", + "epoch": 0, + "base_val_recall@10": 0.8292181491851807, + "val_recall@10": "85.5%", + "dataset": "query", + "split": 0, + "loss": -0.05437087118625641, + "delta": 2.572011947631836, + }, + ) + response_schema = JobResponseStatusSchema( + response=completed_no_improvement_response + ) + response_schema.print_status(job_id, recall="0.0", improvement="0.0") + captured = capsys.readouterr() + assert captured.out == Status.completed_no_improvement + # for jobs that failed job_id = "1338464cd80cab681bfcfff3" failed_response = create_response( @@ -204,7 +243,7 @@ def test_deepmemory_print_status_and_list_jobs(capsys, precomputed_jobs_list): }, ) response_schema = JobResponseStatusSchema(response=failed_response) - response_schema.print_status(job_id, recall=None, importvement=None) + response_schema.print_status(job_id, recall=None, improvement=None) captured = capsys.readouterr() assert captured.out == Status.failed @@ -213,18 +252,21 @@ def test_deepmemory_print_status_and_list_jobs(capsys, precomputed_jobs_list): training_response, completed_response, failed_response, + completed_no_improvement_response, ] recalls = { "1238464cd80cab681bfcfff3": None, "3218464cd80cab681bfcfff3": "85.5", "2138464cd80cab681bfcfff3": "85.5", "1338464cd80cab681bfcfff3": None, + "1338464cd80cab681bfcfw23": "0.0", } improvements = { "1238464cd80cab681bfcfff3": None, "3218464cd80cab681bfcfff3": "2.6", "2138464cd80cab681bfcfff3": "2.6", "1338464cd80cab681bfcfff3": None, + "1338464cd80cab681bfcfw23": "0.0", } response_schema = JobResponseStatusSchema(response=responses) output_str = response_schema.print_jobs( diff --git a/deeplake/client/utils.py b/deeplake/client/utils.py index b488e87284..f364ff08e7 100644 --- a/deeplake/client/utils.py +++ b/deeplake/client/utils.py @@ -145,7 +145,7 @@ def print_status( self, job_id: Union[str, List[str]], recall: str, - importvement: str, + improvement: str, ): if not isinstance(job_id, List): job_id = [job_id] @@ -161,7 +161,7 @@ def print_status( indent=" " * 30, add_vertical_bars=True, recall=recall, - improvement=importvement, + improvement=improvement, ) print(line) @@ -174,7 +174,7 @@ def print_status( " " * 30, add_vertical_bars=True, recall=recall, - improvement=importvement, + improvement=improvement, ) progress_string = "| {:<27}| {:<30}" if progress == "None": @@ -298,6 +298,8 @@ def get_best_recall_improvement(recall, improvement, best_recall): elif float(improvement) < float(bimprovement): return brecall, bimprovement else: + if brecall > recall: + return brecall, bimprovement return recall, improvement diff --git a/deeplake/tests/dummy_data/deep_memory/precomputed_jobs_list.txt b/deeplake/tests/dummy_data/deep_memory/precomputed_jobs_list.txt index 18c8b2c06c..8a096e324a 100644 --- a/deeplake/tests/dummy_data/deep_memory/precomputed_jobs_list.txt +++ b/deeplake/tests/dummy_data/deep_memory/precomputed_jobs_list.txt @@ -5,4 +5,6 @@ ID STATUS RESULTS PROGRESS 2138464cd80cab681bfcfff3 completed recall@10: 85.5% (+2.6%) eta: 100.3 seconds recall@10: 85.5% (+2.6%) 1338464cd80cab681bfcfff3 failed not available yet eta: None seconds - error: list indices must beintegers or slices,not str \ No newline at end of file + error: list indices must beintegers or slices,not str +1338464cd80cab681bfcfw23 completed recall@10: 100.0% (+0.0%) eta: 100.3 seconds + recall@10: 100.0% (+0.0%) \ No newline at end of file