diff --git a/recommendation/dlrm_v2/pytorch/tools/accuracy-dlrm.py b/recommendation/dlrm_v2/pytorch/tools/accuracy-dlrm.py index 873f6a0e1..ce662071f 100644 --- a/recommendation/dlrm_v2/pytorch/tools/accuracy-dlrm.py +++ b/recommendation/dlrm_v2/pytorch/tools/accuracy-dlrm.py @@ -43,7 +43,9 @@ def get_targets(args, qsl_indices): with open(args.aggregation_trace_file) as f: for line in f: sample_boundaries.append(sample_boundaries[-1] + int(line.split(", ")[2])) - assert len(sample_boundaries) == len(qsl_indices) + 1, "Number of samples in trace file does not match number of samples in loadgen accuracy log!" + if len(sample_boundaries) != len(qsl_indices) + 1: + print("Warning: number of samples in trace file ({}) does not match number of samples ({}) in " + "loadgen accuracy log!".format(len(sample_boundaries)-1, len(qsl_indices))) # Get all the ground truth labels in the original order in day_23 print("Parsing ground truth labels from day_23 file...") ground_truths = []