From 486a629ea4d5c5150f452d0b0a196bf71fd2021e Mon Sep 17 00:00:00 2001 From: Jinho Suh <83969361+nv-jinhosuh@users.noreply.github.com> Date: Thu, 15 Feb 2024 14:46:50 -0600 Subject: [PATCH] Hotfix: DLRMv2 Audit Test01 fallback failure (#1626) * Hotfix: DLRMv2 Audit Test01 fallback failure DLRMv2 Audit TEST01 may go to fallback route and the accuracy check script (accuracy-dlrm.py) didn't expect this to happen. It always expects entire sample set to be in the accuracy log while Audit TEST01 would generate subset only. This fixes the Audit TEST01 failure described above. * typo fix --- recommendation/dlrm_v2/pytorch/tools/accuracy-dlrm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/recommendation/dlrm_v2/pytorch/tools/accuracy-dlrm.py b/recommendation/dlrm_v2/pytorch/tools/accuracy-dlrm.py index 873f6a0e1..ce662071f 100644 --- a/recommendation/dlrm_v2/pytorch/tools/accuracy-dlrm.py +++ b/recommendation/dlrm_v2/pytorch/tools/accuracy-dlrm.py @@ -43,7 +43,9 @@ def get_targets(args, qsl_indices): with open(args.aggregation_trace_file) as f: for line in f: sample_boundaries.append(sample_boundaries[-1] + int(line.split(", ")[2])) - assert len(sample_boundaries) == len(qsl_indices) + 1, "Number of samples in trace file does not match number of samples in loadgen accuracy log!" + if len(sample_boundaries) != len(qsl_indices) + 1: + print("Warning: number of samples in trace file ({}) does not match number of samples ({}) in " + "loadgen accuracy log!".format(len(sample_boundaries)-1, len(qsl_indices))) # Get all the ground truth labels in the original order in day_23 print("Parsing ground truth labels from day_23 file...") ground_truths = []