diff --git a/stoq/core.py b/stoq/core.py index 14fc30b..139dc41 100644 --- a/stoq/core.py +++ b/stoq/core.py @@ -632,6 +632,12 @@ async def scan_request( else: payload_idx = hashes_seen[payload_hash] for idx in payload_idx: + if request.payloads[idx].results.payload_id in extracted_payload.results.extracted_from: + # Handle self extracting payload + dup_idx = extracted_payload.results.extracted_from.index(request.payloads[idx].results.payload_id) + del extracted_payload.results.extracted_by[dup_idx] + del extracted_payload.results.extracted_from[dup_idx] + self.log.debug(f'Self extracting payload detected for {request.payloads[idx].results.payload_id}') request.payloads[idx].results.extracted_by.extend( extracted_payload.results.extracted_by ) diff --git a/stoq/tests/test_core.py b/stoq/tests/test_core.py index c28422a..2dbcee3 100644 --- a/stoq/tests/test_core.py +++ b/stoq/tests/test_core.py @@ -298,7 +298,7 @@ async def test_scan_with_duplicate_extracted_payloads(self): 'extract_payload', response.results[1].plugins_run['workers'][0] ) self.assertEqual('simple_worker', response.results[1].extracted_by[0]) - self.assertEqual('extract_payload', response.results[1].extracted_by[1]) + self.assertEqual(1, len(response.results[1].extracted_by)) async def test_scan_with_nested_required_plugin(self): s = Stoq(base_dir=utils.get_data_dir()) @@ -507,6 +507,10 @@ async def test_dedup(self): response = await s.scan(self.generic_content) self.assertEqual(len(response.results), 2) + # Make sure extracted_from does not include the result payload_id + for result in response.results: + self.assertNotIn(result.payload_id, result.extracted_from) + @asynctest.skipIf( sys.version_info >= (3, 8), 'skipping because python >= 3.8 breaks test' )