Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 0 additions & 62 deletions .github/workflows/codeql.yml

This file was deleted.

41 changes: 38 additions & 3 deletions tests/unit/test_ai_rules.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import ast
import json
import os
import re
import sys
import tempfile
import textwrap
Expand All @@ -24,9 +25,9 @@ def _wrap(code: str) -> str:
return f"def _load_model():\n{indented}\n"


def _ai206_rule() -> dict:
def _ai_rule(rule_id: str) -> dict:
rules = toml.loads(RULES_PATH.read_text(encoding="utf-8"))
return next(rule for rule in rules["rule"] if rule["id"] == "AI206")
return next(rule for rule in rules["rule"] if rule["id"] == rule_id)


def _ast_node(node: ast.AST) -> dict:
Expand Down Expand Up @@ -104,7 +105,7 @@ def fires(code: str, rule_id: str) -> bool:

class TestAI206:
def test_rule_metadata(self):
rule = _ai206_rule()
rule = _ai_rule("AI206")
assert rule["severity"] == "High"
assert rule["cwe"] == "CWE-94"
assert rule["ast_match"] == AI206_MATCHER
Expand Down Expand Up @@ -132,3 +133,37 @@ def test_trust_remote_code_false_safe(self):
)
"""
assert not fires(code, "AI206")


class TestAI202:
def test_rule_metadata(self):
rule = _ai_rule("AI202")
assert rule["severity"] == "High"
assert rule["cwe"] == "CWE-502"
assert rule["pattern"] == r"torch\.load\s*\("
assert rule["exclude_pattern"] == r"^\s*#|weights_only\s*=\s*True"

@pytest.mark.parametrize(
"code",
[
'model = torch.load("model.pt")',
"checkpoint = torch.load(path, map_location='cpu')",
],
)
def test_pattern_matches_torch_load_calls(self, code):
rule = _ai_rule("AI202")
assert re.search(rule["pattern"], code)
assert not re.search(rule["exclude_pattern"], code)

@pytest.mark.parametrize(
"code",
[
'# model = torch.load("model.pt")',
'model = torch.load("model.pt", weights_only=True)',
'model = torch.load("model.pt", weights_only = True)',
],
)
def test_exclude_pattern_suppresses_safe_or_comment_cases(self, code):
rule = _ai_rule("AI202")
assert re.search(rule["pattern"], code)
assert re.search(rule["exclude_pattern"], code)