diff --git a/tests/conftest.py b/tests/conftest.py index 5babd34b..178a0cd3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -53,7 +53,7 @@ def setup_shared_env(): del os.environ["MINIO_ENDPOINT"] -@pytest.fixture +@pytest.fixture(scope="session") def create_mock_secret(): def _get_secret(secret_name: str, secrets: dict[str, str]) -> list[Secret]: keys = [SecretKeyValue(key=k, value=v) for k, v in secrets.items()] diff --git a/tests/integrations/test_virustotal.py b/tests/integrations/test_virustotal.py index 2688d466..14b6f711 100644 --- a/tests/integrations/test_virustotal.py +++ b/tests/integrations/test_virustotal.py @@ -11,13 +11,13 @@ ) -@pytest.fixture +@pytest.fixture(scope="module") def virustotal_secret(create_mock_secret) -> dict[str, str | bytes]: mock_secret = create_mock_secret( - "virustotal", {"VIRUSTOTAL_API_KEY": os.environ["VIRUSTOTAL_API_KEY"]} + "virustotal", {"VT_API_KEY": os.environ["VT_API_KEY"]} ) - serialized_secret = mock_secret.model_dump() - return serialized_secret + mock_secret_obj = mock_secret.model_dump_json() + return mock_secret_obj @pytest.mark.parametrize( @@ -26,68 +26,49 @@ def virustotal_secret(create_mock_secret) -> dict[str, str | bytes]: "10c796b7308ac0b9c38f1caa95c798b2b28c46adaa037a9c3a9ebdd3569824e3" ], # Example hash of Mirai malware ) +@pytest.mark.respx(assert_all_mocked=False) def test_get_file_report(virustotal_secret, respx_mock, file_hash): - respx_mock.get(f'{os.environ["TRACECAT__API_URL"]}/secrets/virustotal').mock( - return_value=Response(status_code=200, json=virustotal_secret) + respx_mock.base_url = os.environ["TRACECAT__API_URL"] + respx_mock.get("/secrets/virustotal").mock( + return_value=Response(status_code=200, content=virustotal_secret) ) - result = get_file_report(file_hash) - assert result["sha256"] == file_hash - required_keys = [ - "sandbox_verdicts", - "reputation", - "last_analysis_results", - "last_analysis_stats", - ] - assert all( - key in result for key in required_keys - ), "Some keys are missing in the file report" + respx_mock.route(host="www.virustotal.com").pass_through() + result = get_file_report(file_hash).get("data") + assert result["id"] == file_hash +@pytest.mark.respx(assert_all_mocked=False) def test_get_url_report(virustotal_secret, respx_mock): - url = "http://example.com" - respx_mock.get(f'{os.environ["TRACECAT__API_URL"]}/secrets/virustotal').mock( - return_value=Response(status_code=200, json=virustotal_secret) + url = "http://example.com/" + + respx_mock.base_url = os.environ["TRACECAT__API_URL"] + respx_mock.get("/secrets/virustotal").mock( + return_value=Response(status_code=200, content=virustotal_secret) ) - result = get_url_report(url) + respx_mock.route(host="www.virustotal.com").pass_through() + result = get_url_report(url).get("data").get("attributes") assert result["url"] == url - required_keys = [ - "title", - "last_analysis_results", - "last_analysis_stats", - "total_votes", - "reputation", - ] - assert all( - key in result for key in required_keys - ), "Some keys are missing in the URL report" +@pytest.mark.respx(assert_all_mocked=False) def test_get_domain_report(virustotal_secret, respx_mock): domain = "ycombinator.com" - respx_mock.get(f'{os.environ["TRACECAT__API_URL"]}/secrets/virustotal').mock( - return_value=Response(status_code=200, json=virustotal_secret) + respx_mock.base_url = os.environ["TRACECAT__API_URL"] + respx_mock.get("/secrets/virustotal").mock( + return_value=Response(status_code=200, content=virustotal_secret) ) - result = get_domain_report(domain) - assert result["domain"] == domain - required_keys = [ - "title", - "last_analysis_results", - "last_analysis_stats", - "total_votes", - ] - assert all( - key in result for key in required_keys - ), "Some keys are missing in the domain report" + respx_mock.route(host="www.virustotal.com").pass_through() + result = get_domain_report(domain).get("data") + assert result["id"] == domain +@pytest.mark.respx(assert_all_mocked=False) def test_get_ip_address_report(virustotal_secret, respx_mock): ip = "8.8.8.8" # Google's IP - respx_mock.get(f'{os.environ["TRACECAT__API_URL"]}/secrets/virustotal').mock( - return_value=Response(status_code=200, json=virustotal_secret) + respx_mock.base_url = os.environ["TRACECAT__API_URL"] + respx_mock.get("/secrets/virustotal").mock( + return_value=Response(status_code=200, content=virustotal_secret) ) - result = get_ip_address_report(ip) + respx_mock.route(host="www.virustotal.com").pass_through() + result = get_ip_address_report(ip).get("data") assert result["id"] == ip - required_keys = ["title", "regional_internet_registry", "whois"] - assert all( - key in result for key in required_keys - ), "Some keys are missing in the IP address report" diff --git a/tracecat/integrations/virustotal.py b/tracecat/integrations/virustotal.py index 7f2c1745..8e1269e0 100644 --- a/tracecat/integrations/virustotal.py +++ b/tracecat/integrations/virustotal.py @@ -33,7 +33,7 @@ def get_file_report(file_hash: str) -> dict[str, Any]: """Returns File object: https://docs.virustotal.com/reference/files""" with create_virustotal_client() as client: rsp = client.get( - f"urls/{file_hash}", headers={"x-apikey": os.environ["VT_API_KEY"]} + f"files/{file_hash}", headers={"x-apikey": os.environ["VT_API_KEY"]} ) rsp.raise_for_status() return rsp.json()