From 8e69f22b245288ba58fa8541ef2e49b487314f6e Mon Sep 17 00:00:00 2001 From: Ahmed Lekssays Date: Wed, 8 Oct 2025 10:36:43 +0300 Subject: [PATCH 1/3] add support for swift, ruby, ghidra, csharp, and php --- README.md | 4 ++-- config.example.yaml | 5 +++++ src/models.py | 2 +- src/services/cpg_generator.py | 6 ++++++ src/tools/mcp_tools.py | 2 +- src/utils/validators.py | 2 +- 6 files changed, 16 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 559b37c..fc0c39c 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ A Model Context Protocol (MCP) server that provides AI assistants with static co ## Features -- **Multi-Language Support**: Java, C/C++, JavaScript, Python, Go, Kotlin, Swift +- **Multi-Language Support**: Java, C/C++, JavaScript, Python, Go, Kotlin, C#, Ghidra, Jimple, PHP, Ruby, Swift - **Docker Isolation**: Each analysis session runs in a secure container - **GitHub Integration**: Analyze repositories directly from GitHub URLs - **Session-Based**: Persistent CPG sessions with automatic cleanup @@ -150,7 +150,7 @@ sessions: cpg: generation_timeout: 600 # CPG generation timeout (seconds) - supported_languages: [java, c, cpp, javascript, python, go, kotlin, swift] + supported_languages: [java, c, cpp, javascript, python, go, kotlin, csharp, ghidra, jimple, php, ruby, swift] ``` Environment variables override config file settings (e.g., `MCP_HOST`, `REDIS_HOST`, `SESSION_TTL`). 
diff --git a/config.example.yaml b/config.example.yaml index 581b545..84b4d00 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -29,6 +29,11 @@ cpg: - python - go - kotlin + - csharp + - ghidra + - jimple + - php + - ruby - swift query: diff --git a/src/models.py b/src/models.py index bc53d78..29a81fa 100644 --- a/src/models.py +++ b/src/models.py @@ -129,7 +129,7 @@ class CPGConfig: generation_timeout: int = 600 # 10 minutes max_repo_size_mb: int = 500 supported_languages: List[str] = field(default_factory=lambda: [ - "java", "c", "cpp", "javascript", "python", "go", "kotlin" + "java", "c", "cpp", "javascript", "python", "go", "kotlin", "csharp", "ghidra", "jimple", "php", "ruby", "swift" ]) diff --git a/src/services/cpg_generator.py b/src/services/cpg_generator.py index 79d1650..ef4479a 100644 --- a/src/services/cpg_generator.py +++ b/src/services/cpg_generator.py @@ -26,6 +26,12 @@ class CPGGenerator: "python": "pysrc2cpg", "go": "gosrc2cpg", "kotlin": "kotlin2cpg", + "csharp": "csharpsrc2cpg", + "ghidra": "ghidra2cpg", + "jimple": "jimple2cpg", + "php": "php2cpg", + "ruby": "rubysrc2cpg", + "swift": "swiftsrc2cpg", } def __init__( diff --git a/src/tools/mcp_tools.py b/src/tools/mcp_tools.py index e048e1d..300156f 100644 --- a/src/tools/mcp_tools.py +++ b/src/tools/mcp_tools.py @@ -67,7 +67,7 @@ async def create_cpg_session( source_type: Either "local" or "github" source_path: For local: absolute path to source directory For github: full GitHub URL (e.g., https://github.com/user/repo) - language: Programming language - one of: java, c, cpp, javascript, python, go, kotlin + language: Programming language - one of: java, c, cpp, javascript, python, go, kotlin, csharp, ghidra, jimple, php, ruby, swift github_token: GitHub Personal Access Token for private repositories (optional) branch: Specific git branch to checkout (optional, defaults to default branch) diff --git a/src/utils/validators.py b/src/utils/validators.py index 0a5fa51..981fd8c 100644 --- 
a/src/utils/validators.py +++ b/src/utils/validators.py @@ -21,7 +21,7 @@ def validate_source_type(source_type: str): def validate_language(language: str): """Validate programming language""" - supported = ["java", "c", "cpp", "javascript", "python", "go", "kotlin"] + supported = ["java", "c", "cpp", "javascript", "python", "go", "kotlin", "csharp", "ghidra", "jimple", "php", "ruby", "swift"] if language not in supported: raise ValidationError( f"Unsupported language '{language}'. Supported: {', '.join(supported)}" From c24389d9f2aaf27346e9048e39789179d0782351 Mon Sep 17 00:00:00 2001 From: Ahmed Lekssays Date: Wed, 8 Oct 2025 10:46:29 +0300 Subject: [PATCH 2/3] add limit to returned queries --- src/services/query_executor.py | 22 +++++++++++++--------- src/tools/mcp_tools.py | 14 ++++++++++---- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/src/services/query_executor.py b/src/services/query_executor.py index 5465766..02f5840 100644 --- a/src/services/query_executor.py +++ b/src/services/query_executor.py @@ -69,9 +69,9 @@ def set_cpg_generator(self, cpg_generator): async def execute_query_async( self, session_id: str, - cpg_path: str, query: str, - timeout: Optional[int] = None + timeout: Optional[int] = None, + limit: Optional[int] = 150 ) -> str: """Execute a CPGQL query asynchronously and return query UUID""" try: @@ -85,7 +85,7 @@ async def execute_query_async( container_cpg_path = "/workspace/cpg.bin" # Normalize query to ensure JSON output and pipe to file - query_normalized = self._normalize_query_for_json(query.strip()) + query_normalized = self._normalize_query_for_json(query.strip(), limit) output_file = f"/tmp/query_{query_id}.json" query_with_pipe = f"{query_normalized} #> \"{output_file}\"" @@ -100,7 +100,7 @@ async def execute_query_async( } # Start async execution - asyncio.create_task(self._execute_query_background(query_id, session_id, container_cpg_path, query_with_pipe, timeout)) + 
asyncio.create_task(self._execute_query_background(query_id, session_id, container_cpg_path, query_with_pipe, timeout, limit)) logger.info(f"Started async query {query_id} for session {session_id}") return query_id @@ -113,9 +113,8 @@ async def _execute_query_background( self, query_id: str, session_id: str, - cpg_path: str, query_with_pipe: str, - timeout: Optional[int] + timeout: Optional[int], ): """Execute query in background""" try: @@ -223,7 +222,8 @@ async def execute_query( session_id: str, cpg_path: str, query: str, - timeout: Optional[int] = None + timeout: Optional[int] = None, + limit: Optional[int] = 150 ) -> QueryResult: """Execute a CPGQL query synchronously (for backwards compatibility)""" start_time = time.time() @@ -233,7 +233,7 @@ async def execute_query( validate_cpgql_query(query) # Normalize query to ensure JSON output - query_normalized = self._normalize_query_for_json(query.strip()) + query_normalized = self._normalize_query_for_json(query.strip(), limit) # Check cache if enabled if self.config.cache_enabled and self.redis: @@ -339,7 +339,7 @@ async def cleanup_old_queries(self, max_age_seconds: int = 3600): if to_cleanup: logger.info(f"Cleaned up {len(to_cleanup)} old queries") - def _normalize_query_for_json(self, query: str) -> str: + def _normalize_query_for_json(self, query: str, limit: Optional[int] = None) -> str: """Normalize query to ensure JSON output""" # Remove any existing output modifiers query = query.strip() @@ -352,6 +352,10 @@ def _normalize_query_for_json(self, query: str) -> str: elif query.endswith('.toJsonPretty'): query = query[:-13] + # Add limit if specified + if limit is not None and limit > 0: + query = f"{query}.take({limit})" + # Add .toJsonPretty for proper JSON output return query + '.toJsonPretty' diff --git a/src/tools/mcp_tools.py b/src/tools/mcp_tools.py index 300156f..b0f04b7 100644 --- a/src/tools/mcp_tools.py +++ b/src/tools/mcp_tools.py @@ -332,7 +332,8 @@ async def generate_and_cache(): async def 
run_cpgql_query_async( session_id: str, query: str, - timeout: int = 30 + timeout: int = 30, + limit: Optional[int] = 150 ) -> Dict[str, Any]: """ Executes a CPGQL query asynchronously and returns a query ID for status tracking. @@ -345,6 +346,7 @@ async def run_cpgql_query_async( session_id: The session ID returned from create_cpg_session query: CPGQL query string (automatically converted to JSON output) timeout: Maximum execution time in seconds (default: 30) + limit: Maximum number of results to return (default: 150) Returns: { @@ -381,7 +383,8 @@ async def run_cpgql_query_async( session_id=session_id, cpg_path=session.cpg_path, query=query, - timeout=timeout + timeout=timeout, + limit=limit ) return { @@ -767,7 +770,8 @@ async def cleanup_queries( async def run_cpgql_query( session_id: str, query: str, - timeout: int = 30 + timeout: int = 30, + limit: Optional[int] = 150 ) -> Dict[str, Any]: """ Executes a CPGQL query synchronously on a loaded CPG. @@ -780,6 +784,7 @@ async def run_cpgql_query( session_id: The session ID returned from create_cpg_session query: CPGQL query string (automatically converted to JSON output) timeout: Maximum execution time in seconds (default: 30) + limit: Maximum number of results to return (default: 150) Returns: { @@ -821,7 +826,8 @@ async def run_cpgql_query( session_id=session_id, cpg_path=container_cpg_path, query=query, - timeout=timeout + timeout=timeout, + limit=limit ) return { From 0a458c51a0efb3a2a700e595ee38e8945906be41 Mon Sep 17 00:00:00 2001 From: Ahmed Lekssays Date: Wed, 8 Oct 2025 10:59:56 +0300 Subject: [PATCH 3/3] add tests --- pytest.ini | 15 ++ run_tests.py | 25 +++ src/utils/redis_client.py | 6 +- tests/__init__.py | 1 + tests/test_config.py | 313 +++++++++++++++++++++++++++ tests/test_cpg_generator.py | 385 ++++++++++++++++++++++++++++++++++ tests/test_exceptions.py | 107 ++++++++++ tests/test_logging.py | 206 ++++++++++++++++++ tests/test_main.py | 117 +++++++++++ tests/test_models.py | 271 
++++++++++++++++++++++++ tests/test_redis_client.py | 257 +++++++++++++++++++++++ tests/test_session_manager.py | 283 +++++++++++++++++++++++++ tests/test_utils.py | 283 +++++++++++++++++++++++++ tests/test_validators.py | 306 +++++++++++++++++++++++++++ 14 files changed, 2572 insertions(+), 3 deletions(-) create mode 100644 pytest.ini create mode 100644 run_tests.py create mode 100644 tests/__init__.py create mode 100644 tests/test_config.py create mode 100644 tests/test_cpg_generator.py create mode 100644 tests/test_exceptions.py create mode 100644 tests/test_logging.py create mode 100644 tests/test_main.py create mode 100644 tests/test_models.py create mode 100644 tests/test_redis_client.py create mode 100644 tests/test_session_manager.py create mode 100644 tests/test_utils.py create mode 100644 tests/test_validators.py diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..e44f55f --- /dev/null +++ b/pytest.ini @@ -0,0 +1,15 @@ +[tool:pytest] +testpaths = tests +python_files = test_*.py +python_classes = Test* +python_functions = test_* +addopts = + --verbose + --tb=short + --strict-markers + --disable-warnings + --asyncio-mode=auto +markers = + unit: Unit tests + integration: Integration tests + slow: Slow running tests \ No newline at end of file diff --git a/run_tests.py b/run_tests.py new file mode 100644 index 0000000..6e6a7b5 --- /dev/null +++ b/run_tests.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python3 +""" +Test runner for Joern MCP +""" +import sys +import subprocess +from pathlib import Path + +def run_tests(): + """Run the test suite""" + project_root = Path(__file__).parent + + # Ensure we're in the project root + if not (project_root / "pyproject.toml").exists(): + print("Error: Must run from project root directory") + sys.exit(1) + + # Run pytest + cmd = [sys.executable, "-m", "pytest", "tests/"] + result = subprocess.run(cmd, cwd=project_root) + + sys.exit(result.returncode) + +if __name__ == "__main__": + run_tests() \ No newline at 
end of file diff --git a/src/utils/redis_client.py b/src/utils/redis_client.py index 3c1f5ed..a1fd00a 100644 --- a/src/utils/redis_client.py +++ b/src/utils/redis_client.py @@ -4,7 +4,7 @@ import json import logging from typing import Optional, Dict, Any, List -import redis.asyncio as aioredis +import redis.asyncio as redis from ..models import RedisConfig, Session @@ -16,12 +16,12 @@ class RedisClient: def __init__(self, config: RedisConfig): self.config = config - self.client: Optional[aioredis.Redis] = None + self.client: Optional[redis.Redis] = None async def connect(self): """Establish Redis connection""" try: - self.client = await aioredis.from_url( + self.client = redis.from_url( f"redis://{self.config.host}:{self.config.port}/{self.config.db}", password=self.config.password, decode_responses=self.config.decode_responses diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..902e8a8 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +# Test package for Joern MCP \ No newline at end of file diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..aa3e1af --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,313 @@ +""" +Tests for configuration management +""" +import os +import tempfile +import yaml +import pytest +from unittest.mock import patch, mock_open +from src.config import load_config, _substitute_env_vars, _dict_to_config +from src.models import Config, ServerConfig, RedisConfig, SessionConfig, CPGConfig, QueryConfig, StorageConfig, JoernConfig +from src.exceptions import ValidationError + + +class TestLoadConfig: + """Test configuration loading""" + + def test_load_config_from_file(self): + """Test loading config from YAML file""" + config_data = { + "server": { + "host": "127.0.0.1", + "port": 8080, + "log_level": "DEBUG" + }, + "redis": { + "host": "redis-server", + "port": 6379, + "password": "secret", + "db": 1 + } + } + + with tempfile.NamedTemporaryFile(mode='w', 
suffix='.yaml', delete=False) as f: + yaml.dump(config_data, f) + config_path = f.name + + try: + config = load_config(config_path) + + assert config.server.host == "127.0.0.1" + assert config.server.port == 8080 + assert config.server.log_level == "DEBUG" + assert config.redis.host == "redis-server" + assert config.redis.port == 6379 + assert config.redis.password == "secret" + assert config.redis.db == 1 + finally: + os.unlink(config_path) + + def test_load_config_from_env_vars(self): + """Test loading config from environment variables""" + env_vars = { + "MCP_HOST": "0.0.0.0", + "MCP_PORT": "4242", + "MCP_LOG_LEVEL": "INFO", + "REDIS_HOST": "localhost", + "REDIS_PORT": "6379", + "REDIS_PASSWORD": "testpass", + "REDIS_DB": "2", + "JOERN_BINARY_PATH": "/usr/bin/joern", + "JOERN_MEMORY_LIMIT": "4g", + "SESSION_TTL": "7200", + "SESSION_IDLE_TIMEOUT": "1800", + "MAX_CONCURRENT_SESSIONS": "20", + "CPG_GENERATION_TIMEOUT": "1200", + "MAX_REPO_SIZE_MB": "1000", + "QUERY_TIMEOUT": "60", + "QUERY_CACHE_ENABLED": "false", + "QUERY_CACHE_TTL": "600", + "WORKSPACE_ROOT": "/tmp/custom", + "CLEANUP_ON_SHUTDOWN": "false" + } + + with patch.dict(os.environ, env_vars): + config = load_config() + + assert config.server.host == "0.0.0.0" + assert config.server.port == 4242 + assert config.server.log_level == "INFO" + assert config.redis.host == "localhost" + assert config.redis.port == 6379 + assert config.redis.password == "testpass" + assert config.redis.db == 2 + assert config.joern.binary_path == "/usr/bin/joern" + assert config.joern.memory_limit == "4g" + assert config.sessions.ttl == 7200 + assert config.sessions.idle_timeout == 1800 + assert config.sessions.max_concurrent == 20 + assert config.cpg.generation_timeout == 1200 + assert config.cpg.max_repo_size_mb == 1000 + assert config.query.timeout == 60 + assert config.query.cache_enabled is False + assert config.query.cache_ttl == 600 + assert config.storage.workspace_root == "/tmp/custom" + assert 
config.storage.cleanup_on_shutdown is False + + def test_load_config_defaults(self): + """Test loading config with defaults""" + # Clear environment + with patch.dict(os.environ, {}, clear=True): + config = load_config() + + assert config.server.host == "0.0.0.0" + assert config.server.port == 4242 + assert config.server.log_level == "INFO" + assert config.redis.host == "localhost" + assert config.redis.port == 6379 + assert config.redis.password is None + assert config.redis.db == 0 + assert config.joern.binary_path == "joern" + assert config.joern.memory_limit == "4g" + assert config.sessions.ttl == 3600 + assert config.sessions.idle_timeout == 1800 + assert config.sessions.max_concurrent == 10 + assert config.cpg.generation_timeout == 600 + assert config.cpg.max_repo_size_mb == 500 + assert config.query.timeout == 30 + assert config.query.cache_enabled is True + assert config.query.cache_ttl == 300 + assert config.storage.workspace_root == "/tmp/joern-mcp" + assert config.storage.cleanup_on_shutdown is True + + def test_load_config_file_not_found(self): + """Test loading config when file doesn't exist""" + config = load_config("/nonexistent/config.yaml") + + # Should fall back to environment/defaults + assert isinstance(config, Config) + + def test_substitute_env_vars(self): + """Test environment variable substitution""" + data = { + "host": "${TEST_HOST}", + "port": 8080, + "nested": { + "path": "${TEST_PATH}", + "value": "static" + }, + "list": ["${TEST_ITEM1}", "static", "${TEST_ITEM2}"] + } + + env_vars = { + "TEST_HOST": "localhost", + "TEST_PATH": "/tmp/test", + "TEST_ITEM1": "item1", + "TEST_ITEM2": "item2" + } + + with patch.dict(os.environ, env_vars): + result = _substitute_env_vars(data) + + assert result["host"] == "localhost" + assert result["port"] == 8080 + assert result["nested"]["path"] == "/tmp/test" + assert result["nested"]["value"] == "static" + assert result["list"] == ["item1", "static", "item2"] + + def 
test_substitute_env_vars_with_defaults(self): + """Test environment variable substitution with defaults""" + data = { + "host": "${TEST_HOST:default_host}", + "missing": "${MISSING_VAR:default_value}" + } + + env_vars = { + "TEST_HOST": "actual_host" + } + + with patch.dict(os.environ, env_vars): + result = _substitute_env_vars(data) + + assert result["host"] == "actual_host" + assert result["missing"] == "default_value" + + def test_substitute_env_vars_no_substitution(self): + """Test that non-template strings are unchanged""" + data = { + "host": "localhost", + "port": 8080, + "path": "/tmp/test" + } + + result = _substitute_env_vars(data) + assert result == data + + +class TestDictToConfig: + """Test dictionary to config conversion""" + + def test_dict_to_config_full(self): + """Test converting full config dictionary""" + data = { + "server": { + "host": "127.0.0.1", + "port": 8080, + "log_level": "DEBUG" + }, + "redis": { + "host": "redis-server", + "port": 6379, + "password": "secret", + "db": 1 + }, + "joern": { + "binary_path": "/usr/bin/joern", + "memory_limit": "8g" + }, + "sessions": { + "ttl": 7200, + "idle_timeout": 3600, + "max_concurrent": 20 + }, + "cpg": { + "generation_timeout": 1200, + "max_repo_size_mb": 1000 + }, + "query": { + "timeout": 60, + "cache_enabled": False, + "cache_ttl": 600 + }, + "storage": { + "workspace_root": "/tmp/custom", + "cleanup_on_shutdown": False + } + } + + config = _dict_to_config(data) + + assert config.server.host == "127.0.0.1" + assert config.server.port == 8080 + assert config.server.log_level == "DEBUG" + assert config.redis.host == "redis-server" + assert config.redis.port == 6379 + assert config.redis.password == "secret" + assert config.redis.db == 1 + assert config.joern.binary_path == "/usr/bin/joern" + assert config.joern.memory_limit == "8g" + assert config.sessions.ttl == 7200 + assert config.sessions.idle_timeout == 3600 + assert config.sessions.max_concurrent == 20 + assert config.cpg.generation_timeout 
== 1200 + assert config.cpg.max_repo_size_mb == 1000 + assert config.query.timeout == 60 + assert config.query.cache_enabled is False + assert config.query.cache_ttl == 600 + assert config.storage.workspace_root == "/tmp/custom" + assert config.storage.cleanup_on_shutdown is False + + def test_dict_to_config_partial(self): + """Test converting partial config dictionary""" + data = { + "server": { + "port": 9000 + }, + "redis": { + "host": "custom-redis" + } + } + + config = _dict_to_config(data) + + # Specified values + assert config.server.port == 9000 + assert config.redis.host == "custom-redis" + + # Default values + assert config.server.host == "0.0.0.0" + assert config.server.log_level == "INFO" + assert config.redis.port == 6379 + assert config.redis.password is None + assert config.redis.db == 0 + + def test_dict_to_config_empty(self): + """Test converting empty config dictionary""" + config = _dict_to_config({}) + + # All default values + assert config.server.host == "0.0.0.0" + assert config.server.port == 4242 + assert config.redis.host == "localhost" + assert config.redis.port == 6379 + + def test_dict_to_config_type_conversions(self): + """Test type conversions in config""" + data = { + "server": { + "port": "9000", # String to int + "log_level": "INFO" + }, + "redis": { + "port": "6380", # String to int + "db": "2", # String to int + "decode_responses": "false" # String to bool + }, + "query": { + "cache_enabled": "true", # String to bool + "timeout": "45" # String to int + }, + "storage": { + "cleanup_on_shutdown": "false" # String to bool + } + } + + config = _dict_to_config(data) + + assert config.server.port == 9000 + assert config.redis.port == 6380 + assert config.redis.db == 2 + assert config.redis.decode_responses is False + assert config.query.cache_enabled is True + assert config.query.timeout == 45 + assert config.storage.cleanup_on_shutdown is False \ No newline at end of file diff --git a/tests/test_cpg_generator.py 
b/tests/test_cpg_generator.py new file mode 100644 index 0000000..c06aa12 --- /dev/null +++ b/tests/test_cpg_generator.py @@ -0,0 +1,385 @@ +""" +Tests for CPG generator +""" +import pytest +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch, call +from src.services.cpg_generator import CPGGenerator +from src.models import CPGConfig, SessionStatus +from src.services.session_manager import SessionManager +from src.exceptions import CPGGenerationError + + +class TestCPGGenerator: + """Test CPG generator functionality""" + + @pytest.fixture + def cpg_config(self): + """CPG configuration fixture""" + return CPGConfig( + generation_timeout=600, + max_repo_size_mb=500, + supported_languages=["java", "python", "c", "cpp"] + ) + + @pytest.fixture + def mock_session_manager(self): + """Mock session manager fixture""" + return AsyncMock(spec=SessionManager) + + @pytest.fixture + def cpg_generator(self, cpg_config, mock_session_manager): + """CPG generator fixture""" + generator = CPGGenerator(cpg_config, mock_session_manager) + return generator + + @pytest.mark.asyncio + async def test_initialize_success(self, cpg_generator): + """Test successful Docker client initialization""" + mock_docker_client = MagicMock() + mock_docker_client.ping = MagicMock() + + with patch('docker.from_env', return_value=mock_docker_client): + await cpg_generator.initialize() + + assert cpg_generator.docker_client == mock_docker_client + + @pytest.mark.asyncio + async def test_initialize_failure(self, cpg_generator): + """Test Docker client initialization failure""" + with patch('docker.from_env', side_effect=Exception("Docker not available")): + with pytest.raises(CPGGenerationError, match="Docker initialization failed"): + await cpg_generator.initialize() + + @pytest.mark.asyncio + async def test_create_session_container(self, cpg_generator): + """Test creating Docker container for session""" + mock_container = MagicMock() + mock_container.id = "container-123" + + 
mock_docker_client = MagicMock() + mock_docker_client.containers.run = MagicMock(return_value=mock_container) + cpg_generator.docker_client = mock_docker_client + + container_id = await cpg_generator.create_session_container( + session_id="session-123", + workspace_path="/tmp/workspace" + ) + + assert container_id == "container-123" + assert cpg_generator.session_containers["session-123"] == "container-123" + + # Verify container creation call + mock_docker_client.containers.run.assert_called_once() + call_kwargs = mock_docker_client.containers.run.call_args[1] + + assert call_kwargs["image"] == "joern:latest" + assert call_kwargs["name"] == "joern-session-session-123" + assert call_kwargs["detach"] is True + assert "/tmp/workspace" in str(call_kwargs["volumes"]) + + @pytest.mark.asyncio + async def test_create_session_container_failure(self, cpg_generator): + """Test container creation failure""" + mock_docker_client = MagicMock() + mock_docker_client.containers.run = MagicMock(side_effect=Exception("Container creation failed")) + cpg_generator.docker_client = mock_docker_client + + with pytest.raises(CPGGenerationError, match="Container creation failed"): + await cpg_generator.create_session_container( + session_id="session-123", + workspace_path="/tmp/workspace" + ) + + @pytest.mark.asyncio + async def test_generate_cpg_java(self, cpg_generator, mock_session_manager): + """Test CPG generation for Java project""" + # Setup mocks + mock_container = MagicMock() + mock_container.exec_run = MagicMock(return_value=MagicMock(output=b"CPG generated successfully", exit_code=0)) + + mock_docker_client = MagicMock() + mock_docker_client.containers.get = MagicMock(return_value=mock_container) + cpg_generator.docker_client = mock_docker_client + cpg_generator.session_containers["session-123"] = "container-123" + + # Mock the helper methods + with patch.object(cpg_generator, '_find_joern_executable', return_value="javasrc2cpg"), \ + patch.object(cpg_generator, 
'_exec_command_async', return_value=""), \ + patch.object(cpg_generator, '_validate_cpg_async', return_value=True): + + mock_session_manager.update_status = AsyncMock() + mock_session_manager.update_session = AsyncMock() + + result = await cpg_generator.generate_cpg( + session_id="session-123", + source_path="/workspace/src", + language="java" + ) + + assert result == "/playground/cpgs/session-123.cpg" + mock_session_manager.update_status.assert_any_call("session-123", SessionStatus.GENERATING.value) + mock_session_manager.update_session.assert_called_once() + + @pytest.mark.asyncio + async def test_generate_cpg_python(self, cpg_generator, mock_session_manager): + """Test CPG generation for Python project""" + mock_container = MagicMock() + mock_container.exec_run = MagicMock(return_value=MagicMock(output=b"CPG generated successfully", exit_code=0)) + + mock_docker_client = MagicMock() + mock_docker_client.containers.get = MagicMock(return_value=mock_container) + cpg_generator.docker_client = mock_docker_client + cpg_generator.session_containers["session-456"] = "container-456" + + with patch.object(cpg_generator, '_find_joern_executable', return_value="pysrc2cpg"), \ + patch.object(cpg_generator, '_exec_command_async', return_value=""), \ + patch.object(cpg_generator, '_validate_cpg_async', return_value=True): + + mock_session_manager.update_status = AsyncMock() + mock_session_manager.update_session = AsyncMock() + + result = await cpg_generator.generate_cpg( + session_id="session-456", + source_path="/workspace/src", + language="python" + ) + + assert result == "/playground/cpgs/session-456.cpg" + + @pytest.mark.asyncio + async def test_generate_cpg_timeout(self, cpg_generator, mock_session_manager): + """Test CPG generation timeout""" + mock_container = MagicMock() + mock_docker_client = MagicMock() + mock_docker_client.containers.get = MagicMock(return_value=mock_container) + cpg_generator.docker_client = mock_docker_client + 
cpg_generator.session_containers["session-123"] = "container-123" + + with patch.object(cpg_generator, '_find_joern_executable', return_value="javasrc2cpg"), \ + patch.object(cpg_generator, '_exec_command_async', side_effect=asyncio.TimeoutError()): + + mock_session_manager.update_status = AsyncMock() + + with pytest.raises(CPGGenerationError, match="CPG generation timed out"): + await cpg_generator.generate_cpg( + session_id="session-123", + source_path="/workspace/src", + language="java" + ) + + @pytest.mark.asyncio + async def test_generate_cpg_validation_failure(self, cpg_generator, mock_session_manager): + """Test CPG generation with validation failure""" + mock_container = MagicMock() + mock_docker_client = MagicMock() + mock_docker_client.containers.get = MagicMock(return_value=mock_container) + cpg_generator.docker_client = mock_docker_client + cpg_generator.session_containers["session-123"] = "container-123" + + with patch.object(cpg_generator, '_find_joern_executable', return_value="javasrc2cpg"), \ + patch.object(cpg_generator, '_exec_command_async', return_value=""), \ + patch.object(cpg_generator, '_validate_cpg_async', return_value=False): + + mock_session_manager.update_status = AsyncMock() + + with pytest.raises(CPGGenerationError, match="CPG file was not created"): + await cpg_generator.generate_cpg( + session_id="session-123", + source_path="/workspace/src", + language="java" + ) + + def test_language_commands_mapping(self, cpg_generator): + """Test language to command mapping""" + expected_commands = { + "java": "javasrc2cpg", + "c": "c2cpg", + "cpp": "c2cpg", + "javascript": "jssrc2cpg", + "python": "pysrc2cpg", + "go": "gosrc2cpg", + "kotlin": "kotlin2cpg", + "csharp": "csharpsrc2cpg", + "ghidra": "ghidra2cpg", + "jimple": "jimple2cpg", + "php": "php2cpg", + "ruby": "rubysrc2cpg", + "swift": "swiftsrc2cpg", + } + + assert cpg_generator.LANGUAGE_COMMANDS == expected_commands + + @pytest.mark.asyncio + async def 
test_find_joern_executable_found(self, cpg_generator): + """Test finding Joern executable successfully""" + mock_container = MagicMock() + # Mock successful test for javasrc2cpg at the first path + mock_container.exec_run = MagicMock(side_effect=[ + MagicMock(exit_code=0), # First path succeeds + ]) + + result = await cpg_generator._find_joern_executable(mock_container, "javasrc2cpg") + + assert result == "/opt/joern/joern-cli/javasrc2cpg" + + @pytest.mark.asyncio + async def test_find_joern_executable_not_found(self, cpg_generator): + """Test finding Joern executable when not found""" + mock_container = MagicMock() + # Mock failed tests for all paths + mock_container.exec_run = MagicMock(return_value=MagicMock(exit_code=1)) + + result = await cpg_generator._find_joern_executable(mock_container, "javasrc2cpg") + + assert result == "javasrc2cpg" # Falls back to base command + + @pytest.mark.asyncio + async def test_validate_cpg_success(self, cpg_generator): + """Test successful CPG validation""" + mock_container = MagicMock() + mock_exec_result = MagicMock() + mock_exec_result.output = b"-rw-r--r-- 1 user user 1024 Jan 1 12:00 /playground/cpgs/session-123.cpg" + mock_container.exec_run = MagicMock(return_value=mock_exec_result) + + result = await cpg_generator._validate_cpg_async(mock_container, "/playground/cpgs/session-123.cpg") + + assert result is True + + @pytest.mark.asyncio + async def test_validate_cpg_failure(self, cpg_generator): + """Test CPG validation failure""" + mock_container = MagicMock() + mock_exec_result = MagicMock() + mock_exec_result.output = b"ls: cannot access '/playground/cpgs/session-123.cpg': No such file or directory" + mock_container.exec_run = MagicMock(return_value=mock_exec_result) + + result = await cpg_generator._validate_cpg_async(mock_container, "/playground/cpgs/session-123.cpg") + + assert result is False + + @pytest.mark.asyncio + async def test_get_container_id(self, cpg_generator): + """Test getting container ID for 
session""" + cpg_generator.session_containers["session-123"] = "container-456" + + result = await cpg_generator.get_container_id("session-123") + + assert result == "container-456" + + @pytest.mark.asyncio + async def test_get_container_id_not_found(self, cpg_generator): + """Test getting container ID for non-existent session""" + result = await cpg_generator.get_container_id("nonexistent") + + assert result is None + + def test_register_session_container(self, cpg_generator): + """Test registering externally created container""" + cpg_generator.register_session_container("session-123", "container-456") + + assert cpg_generator.session_containers["session-123"] == "container-456" + + @pytest.mark.asyncio + async def test_close_session(self, cpg_generator): + """Test closing session container""" + mock_container = MagicMock() + mock_docker_client = MagicMock() + mock_docker_client.containers.get = MagicMock(return_value=mock_container) + cpg_generator.docker_client = mock_docker_client + cpg_generator.session_containers["session-123"] = "container-456" + + await cpg_generator.close_session("session-123") + + mock_docker_client.containers.get.assert_called_once_with("container-456") + mock_container.stop.assert_called_once() + mock_container.remove.assert_called_once() + assert "session-123" not in cpg_generator.session_containers + + @pytest.mark.asyncio + async def test_close_session_error(self, cpg_generator): + """Test closing session with container error""" + mock_docker_client = MagicMock() + mock_docker_client.containers.get = MagicMock(side_effect=Exception("Container not found")) + cpg_generator.docker_client = mock_docker_client + cpg_generator.session_containers["session-123"] = "container-456" + + # Should not raise exception + await cpg_generator.close_session("session-123") + + assert "session-123" not in cpg_generator.session_containers + + @pytest.mark.asyncio + async def test_cleanup(self, cpg_generator): + """Test cleanup of all containers""" + 
cpg_generator.session_containers = { + "session1": "container1", + "session2": "container2" + } + + with patch.object(cpg_generator, 'close_session', new_callable=AsyncMock) as mock_close: + await cpg_generator.cleanup() + + assert mock_close.call_count == 2 + mock_close.assert_any_call("session1") + mock_close.assert_any_call("session2") + + @pytest.mark.asyncio + async def test_stream_logs(self, cpg_generator): + """Test streaming logs during CPG generation""" + mock_container = MagicMock() + mock_exec_result = MagicMock() + mock_exec_result.output = [b"Starting CPG generation...\n", b"Processing files...\n", b"CPG created successfully\n"] + mock_container.exec_run = MagicMock(return_value=mock_exec_result) + + mock_docker_client = MagicMock() + mock_docker_client.containers.get = MagicMock(return_value=mock_container) + cpg_generator.docker_client = mock_docker_client + cpg_generator.session_containers["session-123"] = "container-123" + + with patch.object(cpg_generator, '_find_joern_executable', return_value="javasrc2cpg"): + logs = [] + async for log in cpg_generator.stream_logs( + session_id="session-123", + source_path="/workspace/src", + language="java", + output_path="/output.cpg" + ): + logs.append(log) + + assert len(logs) == 3 + assert "Starting CPG generation..." 
in logs[0] + + @pytest.mark.asyncio + async def test_stream_logs_no_container(self, cpg_generator): + """Test streaming logs when no container exists""" + logs = [] + async for log in cpg_generator.stream_logs( + session_id="session-123", + source_path="/workspace/src", + language="java", + output_path="/output.cpg" + ): + logs.append(log) + + assert len(logs) == 1 + assert "ERROR: No container found" in logs[0] + + @pytest.mark.asyncio + async def test_stream_logs_unsupported_language(self, cpg_generator): + """Test streaming logs with unsupported language""" + cpg_generator.session_containers["session-123"] = "container-123" + # Mock docker client to avoid NoneType error + cpg_generator.docker_client = MagicMock() + + logs = [] + async for log in cpg_generator.stream_logs( + session_id="session-123", + source_path="/workspace/src", + language="unsupported", + output_path="/output.cpg" + ): + logs.append(log) + + assert len(logs) == 1 + assert "ERROR: Unsupported language: unsupported" in logs[0] \ No newline at end of file diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py new file mode 100644 index 0000000..89e3e21 --- /dev/null +++ b/tests/test_exceptions.py @@ -0,0 +1,107 @@ +""" +Tests for custom exceptions +""" +import pytest +from src.exceptions import ( + JoernMCPError, + SessionNotFoundError, + SessionNotReadyError, + CPGGenerationError, + QueryExecutionError, + DockerError, + ResourceLimitError, + ValidationError, + GitOperationError +) + + +class TestExceptions: + """Test custom exception classes""" + + def test_base_exception(self): + """Test base JoernMCPError""" + error = JoernMCPError("Test error") + assert str(error) == "Test error" + assert isinstance(error, Exception) + + def test_session_not_found_error(self): + """Test SessionNotFoundError""" + error = SessionNotFoundError("Session 123 not found") + assert str(error) == "Session 123 not found" + assert isinstance(error, JoernMCPError) + + def test_session_not_ready_error(self): + 
"""Test SessionNotReadyError""" + error = SessionNotReadyError("Session not ready") + assert str(error) == "Session not ready" + assert isinstance(error, JoernMCPError) + + def test_cpg_generation_error(self): + """Test CPGGenerationError""" + error = CPGGenerationError("CPG generation failed") + assert str(error) == "CPG generation failed" + assert isinstance(error, JoernMCPError) + + def test_query_execution_error(self): + """Test QueryExecutionError""" + error = QueryExecutionError("Query execution failed") + assert str(error) == "Query execution failed" + assert isinstance(error, JoernMCPError) + + def test_docker_error(self): + """Test DockerError""" + error = DockerError("Docker operation failed") + assert str(error) == "Docker operation failed" + assert isinstance(error, JoernMCPError) + + def test_resource_limit_error(self): + """Test ResourceLimitError""" + error = ResourceLimitError("Resource limit exceeded") + assert str(error) == "Resource limit exceeded" + assert isinstance(error, JoernMCPError) + + def test_validation_error(self): + """Test ValidationError""" + error = ValidationError("Invalid input") + assert str(error) == "Invalid input" + assert isinstance(error, JoernMCPError) + + def test_git_operation_error(self): + """Test GitOperationError""" + error = GitOperationError("Git operation failed") + assert str(error) == "Git operation failed" + assert isinstance(error, JoernMCPError) + + def test_exception_hierarchy(self): + """Test that all exceptions inherit from JoernMCPError""" + exceptions = [ + SessionNotFoundError("test"), + SessionNotReadyError("test"), + CPGGenerationError("test"), + QueryExecutionError("test"), + DockerError("test"), + ResourceLimitError("test"), + ValidationError("test"), + GitOperationError("test") + ] + + for exc in exceptions: + assert isinstance(exc, JoernMCPError) + assert isinstance(exc, Exception) + + def test_exception_with_custom_message(self): + """Test exceptions with custom messages""" + test_cases = [ + 
(SessionNotFoundError, "Session abc-123 not found"), + (SessionNotReadyError, "Session is still generating"), + (CPGGenerationError, "Failed to generate CPG for Java project"), + (QueryExecutionError, "Invalid CPGQL syntax"), + (DockerError, "Cannot connect to Docker daemon"), + (ResourceLimitError, "Maximum concurrent sessions reached"), + (ValidationError, "Unsupported language: rust"), + (GitOperationError, "Repository not found") + ] + + for exc_class, message in test_cases: + error = exc_class(message) + assert str(error) == message \ No newline at end of file diff --git a/tests/test_logging.py b/tests/test_logging.py new file mode 100644 index 0000000..fbb1eaa --- /dev/null +++ b/tests/test_logging.py @@ -0,0 +1,206 @@ +""" +Tests for logging configuration +""" +import logging +import sys +from unittest.mock import patch, MagicMock +from src.utils.logging import setup_logging, get_logger + + +class TestSetupLogging: + """Test logging setup functionality""" + + def test_setup_logging_default_level(self): + """Test setting up logging with default INFO level""" + with patch('logging.getLogger') as mock_get_logger, \ + patch('sys.stdout', create=True) as mock_stdout: + + mock_root_logger = MagicMock() + mock_docker_logger = MagicMock() + mock_urllib_logger = MagicMock() + mock_git_logger = MagicMock() + + def get_logger_side_effect(name=None): + if name == 'docker': + return mock_docker_logger + elif name == 'urllib3': + return mock_urllib_logger + elif name == 'git': + return mock_git_logger + else: + return mock_root_logger + + mock_get_logger.side_effect = get_logger_side_effect + + setup_logging() + + # Verify root logger configuration + mock_root_logger.setLevel.assert_called_once_with(logging.INFO) + assert mock_root_logger.addHandler.called + + # Verify handler was added + handler_call = mock_root_logger.addHandler.call_args[0][0] + assert isinstance(handler_call, logging.StreamHandler) + assert handler_call.level == logging.INFO + + # Verify library 
loggers are set to WARNING + mock_docker_logger.setLevel.assert_called_once_with(logging.WARNING) + mock_urllib_logger.setLevel.assert_called_once_with(logging.WARNING) + mock_git_logger.setLevel.assert_called_once_with(logging.WARNING) + + def test_setup_logging_custom_level(self): + """Test setting up logging with custom log level""" + with patch('logging.getLogger') as mock_get_logger, \ + patch('sys.stdout', create=True) as mock_stdout: + + mock_root_logger = MagicMock() + mock_docker_logger = MagicMock() + mock_urllib_logger = MagicMock() + mock_git_logger = MagicMock() + + def get_logger_side_effect(name=None): + if name == 'docker': + return mock_docker_logger + elif name == 'urllib3': + return mock_urllib_logger + elif name == 'git': + return mock_git_logger + else: + return mock_root_logger + + mock_get_logger.side_effect = get_logger_side_effect + + setup_logging("DEBUG") + + # Verify root logger configuration + mock_root_logger.setLevel.assert_called_once_with(logging.DEBUG) + + def test_setup_logging_invalid_level(self): + """Test setting up logging with invalid level defaults to INFO""" + with patch('logging.getLogger') as mock_get_logger, \ + patch('sys.stdout', create=True) as mock_stdout: + + mock_root_logger = MagicMock() + mock_docker_logger = MagicMock() + mock_urllib_logger = MagicMock() + mock_git_logger = MagicMock() + + def get_logger_side_effect(name=None): + if name == 'docker': + return mock_docker_logger + elif name == 'urllib3': + return mock_urllib_logger + elif name == 'git': + return mock_git_logger + else: + return mock_root_logger + + mock_get_logger.side_effect = get_logger_side_effect + + setup_logging("INVALID") + + # Should default to INFO for invalid level + mock_root_logger.setLevel.assert_called_once_with(logging.INFO) + + def test_setup_logging_removes_existing_handlers(self): + """Test that existing handlers are removed before setup""" + with patch('logging.getLogger') as mock_get_logger, \ + patch('sys.stdout', 
create=True) as mock_stdout: + + mock_root_logger = MagicMock() + mock_root_logger.handlers = [MagicMock(), MagicMock()] # Existing handlers + mock_get_logger.return_value = mock_root_logger + + setup_logging() + + # Verify existing handlers were removed + assert mock_root_logger.removeHandler.call_count == 2 + + def test_setup_logging_library_noise_reduction(self): + """Test that noisy library loggers are configured""" + with patch('logging.getLogger') as mock_get_logger, \ + patch('sys.stdout', create=True) as mock_stdout: + + mock_root_logger = MagicMock() + mock_docker_logger = MagicMock() + mock_urllib_logger = MagicMock() + mock_git_logger = MagicMock() + + def get_logger_side_effect(name=None): + if name == 'docker': + return mock_docker_logger + elif name == 'urllib3': + return mock_urllib_logger + elif name == 'git': + return mock_git_logger + else: + return mock_root_logger + + mock_get_logger.side_effect = get_logger_side_effect + + setup_logging() + + # Verify library loggers are set to WARNING + mock_docker_logger.setLevel.assert_called_once_with(logging.WARNING) + mock_urllib_logger.setLevel.assert_called_once_with(logging.WARNING) + mock_git_logger.setLevel.assert_called_once_with(logging.WARNING) + + def test_setup_logging_formatter(self): + """Test that handlers get proper formatter""" + with patch('logging.getLogger') as mock_get_logger, \ + patch('sys.stdout', create=True) as mock_stdout: + + mock_root_logger = MagicMock() + mock_get_logger.return_value = mock_root_logger + + setup_logging() + + # Get the handler that was added + handler = mock_root_logger.addHandler.call_args[0][0] + + # Verify formatter was set + assert handler.formatter is not None + assert isinstance(handler.formatter, logging.Formatter) + + # Test the format string + format_str = handler.formatter._fmt + expected_parts = [ + '%(asctime)s', + '%(name)s', + '%(levelname)s', + '%(message)s' + ] + + for part in expected_parts: + assert part in format_str + + +class 
TestGetLogger: + """Test logger retrieval""" + + def test_get_logger(self): + """Test getting a logger instance""" + with patch('logging.getLogger') as mock_get_logger: + mock_logger = MagicMock() + mock_get_logger.return_value = mock_logger + + logger = get_logger("test.module") + + mock_get_logger.assert_called_once_with("test.module") + assert logger == mock_logger + + def test_get_logger_different_names(self): + """Test getting loggers with different names""" + with patch('logging.getLogger') as mock_get_logger: + mock_logger1 = MagicMock() + mock_logger2 = MagicMock() + mock_get_logger.side_effect = [mock_logger1, mock_logger2] + + logger1 = get_logger("module1") + logger2 = get_logger("module2") + + assert mock_get_logger.call_count == 2 + mock_get_logger.assert_any_call("module1") + mock_get_logger.assert_any_call("module2") + assert logger1 == mock_logger1 + assert logger2 == mock_logger2 \ No newline at end of file diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000..e0eb1bc --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,117 @@ +""" +Tests for main module +""" +import pytest +from unittest.mock import patch, AsyncMock +import sys +from pathlib import Path + +# Add the project root to the path +sys.path.insert(0, str(Path(__file__).parent.parent)) + +# Import main module +import main +lifespan = main.lifespan + + +class TestLifespan: + """Test FastMCP lifespan management""" + + @pytest.mark.asyncio + async def test_lifespan_success(self): + """Test successful lifespan startup and shutdown""" + mock_mcp = AsyncMock() + + # Mock all the services and dependencies + with patch('main.load_config') as mock_load_config, \ + patch('main.RedisClient') as mock_redis_client_class, \ + patch('main.SessionManager') as mock_session_manager_class, \ + patch('main.GitManager') as mock_git_manager_class, \ + patch('main.CPGGenerator') as mock_cpg_generator_class, \ + patch('main.DockerOrchestrator') as mock_docker_orch_class, \ + 
patch('main.QueryExecutor') as mock_query_executor_class, \ + patch('main.setup_logging') as mock_setup_logging, \ + patch('main.logger') as mock_logger, \ + patch('os.makedirs') as mock_makedirs: + + # Setup mocks + mock_config = AsyncMock() + mock_config.server.log_level = "INFO" + mock_config.storage.workspace_root = "/tmp/workspace" + mock_config.redis = AsyncMock() + mock_config.sessions = AsyncMock() + mock_config.cpg = AsyncMock() + mock_config.query = AsyncMock() + mock_config.joern = AsyncMock() + + mock_load_config.return_value = mock_config + + mock_redis_client = AsyncMock() + mock_redis_client_class.return_value = mock_redis_client + + mock_session_manager = AsyncMock() + mock_session_manager_class.return_value = mock_session_manager + + mock_git_manager = AsyncMock() + mock_git_manager_class.return_value = mock_git_manager + + mock_cpg_generator = AsyncMock() + mock_cpg_generator_class.return_value = mock_cpg_generator + + mock_docker_orch = AsyncMock() + mock_docker_orch_class.return_value = mock_docker_orch + + mock_query_executor = AsyncMock() + mock_query_executor_class.return_value = mock_query_executor + + # Test lifespan context manager + async with lifespan(mock_mcp): + # Verify initialization calls + mock_load_config.assert_called_with("config.yaml") + mock_setup_logging.assert_called_with("INFO") + mock_makedirs.assert_called() + mock_redis_client.connect.assert_called_once() + mock_session_manager.set_docker_cleanup_callback.assert_called_once() + mock_cpg_generator.initialize.assert_called_once() + mock_query_executor.initialize.assert_called_once() + + # Verify shutdown calls + mock_query_executor.cleanup.assert_called_once() + mock_docker_orch.cleanup.assert_called_once() + mock_redis_client.close.assert_called_once() + + @pytest.mark.asyncio + async def test_lifespan_initialization_failure(self): + """Test lifespan with initialization failure""" + mock_mcp = AsyncMock() + + with patch('main.load_config', side_effect=Exception("Config 
load failed")), \ + patch('main.logger') as mock_logger: + + with pytest.raises(Exception, match="Config load failed"): + async with lifespan(mock_mcp): + pass + + @pytest.mark.asyncio + async def test_lifespan_redis_connection_failure(self): + """Test lifespan with Redis connection failure""" + mock_mcp = AsyncMock() + + with patch('main.load_config') as mock_load_config, \ + patch('main.RedisClient') as mock_redis_client_class, \ + patch('main.setup_logging'), \ + patch('os.makedirs'), \ + patch('main.logger') as mock_logger: + + mock_config = AsyncMock() + mock_config.server.log_level = "INFO" + mock_config.storage.workspace_root = "/tmp/workspace" + mock_load_config.return_value = mock_config + + mock_redis_client = AsyncMock() + mock_redis_client.connect = AsyncMock(side_effect=Exception("Redis connection failed")) + mock_redis_client_class.return_value = mock_redis_client + + with pytest.raises(Exception, match="Redis connection failed"): + async with lifespan(mock_mcp): + pass \ No newline at end of file diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..adac691 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,271 @@ +""" +Tests for data models +""" +import pytest +from datetime import datetime +from src.models import ( + Session, + SessionStatus, + SourceType, + QueryResult, + Config, + ServerConfig, + RedisConfig, + SessionConfig, + CPGConfig, + QueryConfig, + StorageConfig, + JoernConfig +) + + +class TestSession: + """Test Session model""" + + def test_session_creation(self): + """Test basic session creation""" + session = Session( + id="test-id", + source_type="github", + source_path="https://github.com/user/repo", + language="python" + ) + + assert session.id == "test-id" + assert session.source_type == "github" + assert session.source_path == "https://github.com/user/repo" + assert session.language == "python" + assert session.status == SessionStatus.INITIALIZING.value + assert session.container_id is None 
+ assert session.cpg_path is None + assert isinstance(session.created_at, datetime) + assert isinstance(session.last_accessed, datetime) + assert session.error_message is None + assert session.metadata == {} + + def test_session_to_dict(self): + """Test session serialization""" + session = Session( + id="test-id", + source_type="local", + source_path="/path/to/code", + language="java", + status="ready", + container_id="container-123", + cpg_path="/path/to/cpg.bin", + error_message="Test error" + ) + + data = session.to_dict() + + assert data["id"] == "test-id" + assert data["source_type"] == "local" + assert data["source_path"] == "/path/to/code" + assert data["language"] == "java" + assert data["status"] == "ready" + assert data["container_id"] == "container-123" + assert data["cpg_path"] == "/path/to/cpg.bin" + assert data["error_message"] == "Test error" + assert "created_at" in data + assert "last_accessed" in data + + def test_session_from_dict(self): + """Test session deserialization""" + data = { + "id": "test-id", + "source_type": "github", + "source_path": "https://github.com/user/repo", + "language": "python", + "status": "ready", + "container_id": "container-123", + "cpg_path": "/path/to/cpg.bin", + "created_at": "2023-01-01T12:00:00", + "last_accessed": "2023-01-01T12:30:00", + "error_message": None, + "metadata": {"key": "value"} + } + + session = Session.from_dict(data) + + assert session.id == "test-id" + assert session.source_type == "github" + assert session.source_path == "https://github.com/user/repo" + assert session.language == "python" + assert session.status == "ready" + assert session.container_id == "container-123" + assert session.cpg_path == "/path/to/cpg.bin" + assert session.error_message is None + assert session.metadata == {"key": "value"} + assert isinstance(session.created_at, datetime) + assert isinstance(session.last_accessed, datetime) + + +class TestQueryResult: + """Test QueryResult model""" + + def 
test_query_result_creation(self): + """Test basic query result creation""" + result = QueryResult( + success=True, + data=[{"name": "test"}], + execution_time=1.5 + ) + + assert result.success is True + assert result.data == [{"name": "test"}] + assert result.error is None + assert result.execution_time == 1.5 + assert result.row_count == 0 + + def test_query_result_with_error(self): + """Test query result with error""" + result = QueryResult( + success=False, + error="Query failed", + execution_time=0.5 + ) + + assert result.success is False + assert result.data is None + assert result.error == "Query failed" + assert result.execution_time == 0.5 + + def test_query_result_to_dict(self): + """Test query result serialization""" + result = QueryResult( + success=True, + data=[{"name": "test"}], + execution_time=1.5, + row_count=1 + ) + + data = result.to_dict() + + assert data["success"] is True + assert data["data"] == [{"name": "test"}] + assert data["error"] is None + assert data["execution_time"] == 1.5 + assert data["row_count"] == 1 + + +class TestEnums: + """Test enumeration classes""" + + def test_session_status_values(self): + """Test SessionStatus enum values""" + assert SessionStatus.INITIALIZING.value == "initializing" + assert SessionStatus.GENERATING.value == "generating" + assert SessionStatus.READY.value == "ready" + assert SessionStatus.ERROR.value == "error" + + def test_source_type_values(self): + """Test SourceType enum values""" + assert SourceType.LOCAL.value == "local" + assert SourceType.GITHUB.value == "github" + + +class TestConfigModels: + """Test configuration models""" + + def test_server_config(self): + """Test ServerConfig creation""" + config = ServerConfig( + host="127.0.0.1", + port=8080, + log_level="DEBUG" + ) + + assert config.host == "127.0.0.1" + assert config.port == 8080 + assert config.log_level == "DEBUG" + + def test_redis_config(self): + """Test RedisConfig creation""" + config = RedisConfig( + host="localhost", + 
port=6379, + password="secret", + db=1 + ) + + assert config.host == "localhost" + assert config.port == 6379 + assert config.password == "secret" + assert config.db == 1 + assert config.decode_responses is True + + def test_session_config(self): + """Test SessionConfig creation""" + config = SessionConfig( + ttl=7200, + idle_timeout=3600, + max_concurrent=50 + ) + + assert config.ttl == 7200 + assert config.idle_timeout == 3600 + assert config.max_concurrent == 50 + + def test_cpg_config(self): + """Test CPGConfig creation""" + config = CPGConfig( + generation_timeout=1200, + max_repo_size_mb=1000 + ) + + assert config.generation_timeout == 1200 + assert config.max_repo_size_mb == 1000 + assert "java" in config.supported_languages + assert "python" in config.supported_languages + + def test_query_config(self): + """Test QueryConfig creation""" + config = QueryConfig( + timeout=60, + cache_enabled=False, + cache_ttl=600 + ) + + assert config.timeout == 60 + assert config.cache_enabled is False + assert config.cache_ttl == 600 + + def test_storage_config(self): + """Test StorageConfig creation""" + config = StorageConfig( + workspace_root="/tmp/test", + cleanup_on_shutdown=False + ) + + assert config.workspace_root == "/tmp/test" + assert config.cleanup_on_shutdown is False + + def test_joern_config(self): + """Test JoernConfig creation""" + config = JoernConfig( + binary_path="/usr/local/bin/joern", + memory_limit="8g" + ) + + assert config.binary_path == "/usr/local/bin/joern" + assert config.memory_limit == "8g" + + def test_config_composition(self): + """Test Config composition""" + config = Config( + server=ServerConfig(host="0.0.0.0", port=4242), + redis=RedisConfig(host="redis", port=6379), + joern=JoernConfig(binary_path="joern"), + sessions=SessionConfig(ttl=3600), + cpg=CPGConfig(generation_timeout=600), + query=QueryConfig(timeout=30), + storage=StorageConfig(workspace_root="/tmp/joern") + ) + + assert config.server.host == "0.0.0.0" + assert 
config.redis.host == "redis" + assert config.joern.binary_path == "joern" + assert config.sessions.ttl == 3600 + assert config.cpg.generation_timeout == 600 + assert config.query.timeout == 30 + assert config.storage.workspace_root == "/tmp/joern" \ No newline at end of file diff --git a/tests/test_redis_client.py b/tests/test_redis_client.py new file mode 100644 index 0000000..ed95d02 --- /dev/null +++ b/tests/test_redis_client.py @@ -0,0 +1,257 @@ +""" +Tests for Redis client wrapper +""" +import pytest +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch +from src.utils.redis_client import RedisClient +from src.models import RedisConfig, Session +from src.exceptions import ValidationError + + +class TestRedisClient: + """Test Redis client functionality""" + + @pytest.fixture + def redis_config(self): + """Redis configuration fixture""" + return RedisConfig( + host="localhost", + port=6379, + password=None, + db=0, + decode_responses=True + ) + + @pytest.fixture + def mock_redis(self): + """Mock Redis client""" + mock_redis = AsyncMock() + mock_redis.ping = AsyncMock() + return mock_redis + + @pytest.fixture + def redis_client(self, redis_config, mock_redis): + """Redis client fixture""" + client = RedisClient(redis_config) + client.client = mock_redis + return client + + def test_init(self, redis_config): + """Test Redis client initialization""" + client = RedisClient(redis_config) + + assert client.config == redis_config + assert client.client is None + + @pytest.mark.asyncio + async def test_connect_success(self, redis_config, mock_redis): + """Test successful Redis connection""" + with patch('src.utils.redis_client.redis.from_url', return_value=mock_redis) as mock_from_url: + client = RedisClient(redis_config) + await client.connect() + + mock_from_url.assert_called_once_with( + "redis://localhost:6379/0", + password=None, + decode_responses=True + ) + mock_redis.ping.assert_called_once() + assert client.client == mock_redis + + 
@pytest.mark.asyncio + async def test_connect_failure(self, redis_config): + """Test Redis connection failure""" + with patch('src.utils.redis_client.redis.from_url', side_effect=Exception("Connection failed")): + client = RedisClient(redis_config) + + with pytest.raises(Exception, match="Connection failed"): + await client.connect() + + @pytest.mark.asyncio + async def test_close(self, redis_client, mock_redis): + """Test closing Redis connection""" + await redis_client.close() + + mock_redis.close.assert_called_once() + + @pytest.mark.asyncio + async def test_save_session(self, redis_client, mock_redis): + """Test saving session to Redis""" + session = Session( + id="test-session", + source_type="github", + source_path="https://github.com/user/repo", + language="python" + ) + + mock_redis.set = AsyncMock() + mock_redis.sadd = AsyncMock() + + await redis_client.save_session(session, ttl=3600) + + # Verify session data was saved + mock_redis.set.assert_called_once() + call_args = mock_redis.set.call_args + assert call_args[0][0] == "session:test-session" + assert "test-session" in call_args[0][1] # JSON data + assert call_args[1]["ex"] == 3600 + + # Verify session was added to active set + mock_redis.sadd.assert_called_once_with("sessions:active", "test-session") + + @pytest.mark.asyncio + async def test_get_session_found(self, redis_client, mock_redis): + """Test retrieving existing session""" + session_data = { + "id": "test-session", + "source_type": "local", + "source_path": "/path/to/code", + "language": "java", + "status": "ready", + "created_at": "2023-01-01T12:00:00", + "last_accessed": "2023-01-01T12:30:00" + } + + import json + mock_redis.get = AsyncMock(return_value=json.dumps(session_data)) + + session = await redis_client.get_session("test-session") + + assert session is not None + assert session.id == "test-session" + assert session.source_type == "local" + assert session.language == "java" + + @pytest.mark.asyncio + async def 
test_get_session_not_found(self, redis_client, mock_redis): + """Test retrieving non-existent session""" + mock_redis.get = AsyncMock(return_value=None) + + session = await redis_client.get_session("nonexistent") + + assert session is None + + @pytest.mark.asyncio + async def test_update_session(self, redis_client, mock_redis): + """Test updating session fields""" + # Mock existing session + session_data = { + "id": "test-session", + "source_type": "github", + "source_path": "https://github.com/user/repo", + "language": "python", + "status": "initializing", + "created_at": "2023-01-01T12:00:00", + "last_accessed": "2023-01-01T12:00:00" + } + + import json + mock_redis.get = AsyncMock(return_value=json.dumps(session_data)) + mock_redis.set = AsyncMock() + + await redis_client.update_session("test-session", {"status": "ready", "language": "java"}, ttl=3600) + + # Verify updated session was saved + mock_redis.set.assert_called_once() + call_args = mock_redis.set.call_args + saved_data = json.loads(call_args[0][1]) + assert saved_data["status"] == "ready" + assert saved_data["language"] == "java" + assert call_args[1]["ex"] == 3600 + + @pytest.mark.asyncio + async def test_delete_session(self, redis_client, mock_redis): + """Test deleting session from Redis""" + mock_redis.delete = AsyncMock(return_value=1) + mock_redis.srem = AsyncMock(return_value=1) + + await redis_client.delete_session("test-session") + + mock_redis.delete.assert_called_once_with("session:test-session") + mock_redis.srem.assert_called_once_with("sessions:active", "test-session") + + @pytest.mark.asyncio + async def test_list_sessions(self, redis_client, mock_redis): + """Test listing all active sessions""" + mock_redis.smembers = AsyncMock(return_value={"session1", "session2", "session3"}) + + sessions = await redis_client.list_sessions() + + assert set(sessions) == {"session1", "session2", "session3"} + mock_redis.smembers.assert_called_once_with("sessions:active") + + @pytest.mark.asyncio + async 
def test_touch_session(self, redis_client, mock_redis): + """Test refreshing session TTL""" + mock_redis.expire = AsyncMock(return_value=True) + + await redis_client.touch_session("test-session", ttl=3600) + + mock_redis.expire.assert_called_once_with("session:test-session", 3600) + + @pytest.mark.asyncio + async def test_set_container_mapping(self, redis_client, mock_redis): + """Test setting container mapping""" + mock_redis.set = AsyncMock() + + await redis_client.set_container_mapping("container-123", "session-456", ttl=3600) + + mock_redis.set.assert_called_once_with("container:container-123", "session-456", ex=3600) + + @pytest.mark.asyncio + async def test_get_session_by_container(self, redis_client, mock_redis): + """Test getting session by container ID""" + mock_redis.get = AsyncMock(return_value="session-456") + + session_id = await redis_client.get_session_by_container("container-123") + + assert session_id == "session-456" + mock_redis.get.assert_called_once_with("container:container-123") + + @pytest.mark.asyncio + async def test_delete_container_mapping(self, redis_client, mock_redis): + """Test deleting container mapping""" + mock_redis.delete = AsyncMock(return_value=1) + + await redis_client.delete_container_mapping("container-123") + + mock_redis.delete.assert_called_once_with("container:container-123") + + @pytest.mark.asyncio + async def test_cache_query_result(self, redis_client, mock_redis): + """Test caching query result""" + result = {"data": [{"name": "test"}], "row_count": 1} + + import json + mock_redis.set = AsyncMock() + + await redis_client.cache_query_result("session-123", "query-hash", result, ttl=300) + + mock_redis.set.assert_called_once() + call_args = mock_redis.set.call_args + assert call_args[0][0] == "query:session-123:query-hash" + assert json.loads(call_args[0][1]) == result + assert call_args[1]["ex"] == 300 + + @pytest.mark.asyncio + async def test_get_cached_query(self, redis_client, mock_redis): + """Test retrieving 
cached query result""" + cached_result = {"data": [{"name": "test"}], "row_count": 1} + + import json + mock_redis.get = AsyncMock(return_value=json.dumps(cached_result)) + + result = await redis_client.get_cached_query("session-123", "query-hash") + + assert result == cached_result + mock_redis.get.assert_called_once_with("query:session-123:query-hash") + + @pytest.mark.asyncio + async def test_get_cached_query_not_found(self, redis_client, mock_redis): + """Test retrieving non-existent cached query""" + mock_redis.get = AsyncMock(return_value=None) + + result = await redis_client.get_cached_query("session-123", "query-hash") + + assert result is None \ No newline at end of file diff --git a/tests/test_session_manager.py b/tests/test_session_manager.py new file mode 100644 index 0000000..4b42607 --- /dev/null +++ b/tests/test_session_manager.py @@ -0,0 +1,283 @@ +""" +Tests for session manager +""" +import pytest +import asyncio +from datetime import datetime, timedelta +from unittest.mock import AsyncMock, MagicMock, patch, ANY +from src.services.session_manager import SessionManager +from src.models import Session, SessionStatus, SessionConfig +from src.utils.redis_client import RedisClient +from src.exceptions import SessionNotFoundError, ResourceLimitError + + +class TestSessionManager: + """Test session manager functionality""" + + @pytest.fixture + def session_config(self): + """Session configuration fixture""" + return SessionConfig( + ttl=3600, + idle_timeout=1800, + max_concurrent=10 + ) + + @pytest.fixture + def mock_redis_client(self): + """Mock Redis client fixture""" + mock_client = AsyncMock(spec=RedisClient) + return mock_client + + @pytest.fixture + def session_manager(self, mock_redis_client, session_config): + """Session manager fixture""" + manager = SessionManager(mock_redis_client, session_config) + return manager + + @pytest.mark.asyncio + async def test_create_session_success(self, session_manager, mock_redis_client): + """Test successful 
session creation""" + mock_redis_client.list_sessions = AsyncMock(return_value=[]) + mock_redis_client.save_session = AsyncMock() + + session = await session_manager.create_session( + source_type="github", + source_path="https://github.com/user/repo", + language="python", + options={"branch": "main"} + ) + + assert isinstance(session, Session) + assert session.source_type == "github" + assert session.source_path == "https://github.com/user/repo" + assert session.language == "python" + assert session.status == SessionStatus.INITIALIZING.value + assert session.metadata == {"branch": "main"} + + # Verify Redis calls + mock_redis_client.list_sessions.assert_called_once() + mock_redis_client.save_session.assert_called_once() + + @pytest.mark.asyncio + async def test_create_session_concurrent_limit_reached(self, session_manager, mock_redis_client): + """Test session creation when concurrent limit is reached""" + # Mock existing sessions at limit + existing_sessions = [f"session-{i}" for i in range(10)] + mock_redis_client.list_sessions = AsyncMock(return_value=existing_sessions) + mock_redis_client.get_session = AsyncMock(side_effect=lambda sid: AsyncMock() if sid in existing_sessions else None) + mock_redis_client.save_session = AsyncMock() + + # Mock cleanup of oldest sessions + with patch.object(session_manager, '_cleanup_oldest_sessions', new_callable=AsyncMock) as mock_cleanup: + session = await session_manager.create_session( + source_type="local", + source_path="/path/to/code", + language="java", + options={} + ) + + # Verify cleanup was called + mock_cleanup.assert_called_once_with(10) + + @pytest.mark.asyncio + async def test_create_session_exception_handling(self, session_manager, mock_redis_client): + """Test session creation with exception handling""" + mock_redis_client.list_sessions = AsyncMock(side_effect=Exception("Redis error")) + + with pytest.raises(Exception, match="Redis error"): + await session_manager.create_session( + source_type="github", + 
source_path="https://github.com/user/repo", + language="python", + options={} + ) + + @pytest.mark.asyncio + async def test_get_session_found(self, session_manager, mock_redis_client): + """Test retrieving existing session""" + mock_session = Session( + id="test-session", + source_type="github", + source_path="https://github.com/user/repo", + language="python" + ) + mock_redis_client.get_session = AsyncMock(return_value=mock_session) + + result = await session_manager.get_session("test-session") + + assert result == mock_session + mock_redis_client.get_session.assert_called_once_with("test-session") + + @pytest.mark.asyncio + async def test_get_session_not_found(self, session_manager, mock_redis_client): + """Test retrieving non-existent session""" + mock_redis_client.get_session = AsyncMock(return_value=None) + + result = await session_manager.get_session("nonexistent") + + assert result is None + + @pytest.mark.asyncio + async def test_update_session(self, session_manager, mock_redis_client): + """Test updating session fields""" + mock_redis_client.update_session = AsyncMock() + + await session_manager.update_session("test-session", status="ready", language="java") + + mock_redis_client.update_session.assert_called_once_with( + "test-session", + {"status": "ready", "language": "java", "last_accessed": ANY}, + 3600 + ) + + @pytest.mark.asyncio + async def test_update_status(self, session_manager, mock_redis_client): + """Test updating session status""" + mock_redis_client.update_session = AsyncMock() + + await session_manager.update_status("test-session", "ready", "Operation completed") + + expected_updates = { + "status": "ready", + "error_message": "Operation completed", + "last_accessed": ANY + } + + mock_redis_client.update_session.assert_called_once_with( + "test-session", + expected_updates, + 3600 + ) + + @pytest.mark.asyncio + async def test_list_sessions_no_filters(self, session_manager, mock_redis_client): + """Test listing all sessions without 
filters""" + mock_sessions = [ + Session(id="session1", source_type="github", language="python"), + Session(id="session2", source_type="local", language="java") + ] + + mock_redis_client.list_sessions = AsyncMock(return_value=["session1", "session2"]) + mock_redis_client.get_session = AsyncMock(side_effect=mock_sessions) + + result = await session_manager.list_sessions() + + assert len(result) == 2 + assert result[0].id == "session1" + assert result[1].id == "session2" + + @pytest.mark.asyncio + async def test_list_sessions_with_filters(self, session_manager, mock_redis_client): + """Test listing sessions with filters""" + mock_sessions = [ + Session(id="session1", source_type="github", language="python", status="ready"), + Session(id="session2", source_type="local", language="java", status="generating") + ] + + mock_redis_client.list_sessions = AsyncMock(return_value=["session1", "session2"]) + mock_redis_client.get_session = AsyncMock(side_effect=mock_sessions) + + # Filter by status + result = await session_manager.list_sessions({"status": "ready"}) + + assert len(result) == 1 + assert result[0].id == "session1" + + @pytest.mark.asyncio + async def test_touch_session(self, session_manager, mock_redis_client): + """Test refreshing session TTL""" + mock_redis_client.touch_session = AsyncMock() + mock_redis_client.update_session = AsyncMock() + + await session_manager.touch_session("test-session") + + mock_redis_client.touch_session.assert_called_once_with("test-session", 3600) + mock_redis_client.update_session.assert_called_once_with( + "test-session", + {"last_accessed": ANY}, + 3600 + ) + + @pytest.mark.asyncio + async def test_cleanup_session_success(self, session_manager, mock_redis_client): + """Test successful session cleanup""" + mock_session = Session( + id="test-session", + container_id="container-123", + source_type="github", + source_path="https://github.com/user/repo", + language="python" + ) + + mock_redis_client.get_session = 
AsyncMock(return_value=mock_session) + mock_redis_client.delete_container_mapping = AsyncMock() + mock_redis_client.delete_session = AsyncMock() + + await session_manager.cleanup_session("test-session") + + mock_redis_client.delete_container_mapping.assert_called_once_with("container-123") + mock_redis_client.delete_session.assert_called_once_with("test-session") + + @pytest.mark.asyncio + async def test_cleanup_session_not_found(self, session_manager, mock_redis_client): + """Test cleanup of non-existent session""" + mock_redis_client.get_session = AsyncMock(return_value=None) + + with pytest.raises(SessionNotFoundError): + await session_manager.cleanup_session("nonexistent") + + @pytest.mark.asyncio + async def test_cleanup_idle_sessions(self, session_manager, mock_redis_client): + """Test cleanup of idle sessions""" + # Create sessions with different last_accessed times + now = datetime.utcnow() + active_session = Session( + id="active", + last_accessed=now - timedelta(minutes=25) # Clearly not idle (25 min < 30 min) + ) + idle_session = Session( + id="idle", + last_accessed=now - timedelta(hours=1) # Idle + ) + + mock_redis_client.list_sessions = AsyncMock(return_value=["active", "idle"]) + mock_redis_client.get_session = AsyncMock(side_effect=[active_session, idle_session]) + + with patch.object(session_manager, 'cleanup_session', new_callable=AsyncMock) as mock_cleanup: + await session_manager.cleanup_idle_sessions() + + # Only idle session should be cleaned up + mock_cleanup.assert_called_once_with("idle") + + @pytest.mark.asyncio + async def test_cleanup_oldest_sessions(self, session_manager, mock_redis_client): + """Test cleanup of oldest sessions""" + # Create sessions with different creation times + base_time = datetime.utcnow() + sessions = [ + Session(id="oldest", created_at=base_time - timedelta(hours=3)), + Session(id="middle", created_at=base_time - timedelta(hours=2)), + Session(id="newest", created_at=base_time - timedelta(hours=1)) + ] + + 
mock_redis_client.list_sessions = AsyncMock(return_value=["oldest", "middle", "newest"]) + mock_redis_client.get_session = AsyncMock(side_effect=sessions) + + # Mock docker cleanup + session_manager.docker_cleanup_callback = AsyncMock() + + with patch.object(session_manager, 'cleanup_session', new_callable=AsyncMock) as mock_cleanup: + await session_manager._cleanup_oldest_sessions(2) + + # Two oldest sessions should be cleaned up + assert mock_cleanup.call_count == 2 + mock_cleanup.assert_any_call("oldest") + mock_cleanup.assert_any_call("middle") + + def test_set_docker_cleanup_callback(self, session_manager): + """Test setting Docker cleanup callback""" + callback = AsyncMock() + session_manager.set_docker_cleanup_callback(callback) + + assert session_manager.docker_cleanup_callback == callback \ No newline at end of file diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..5b6d474 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,283 @@ +""" +Tests for utility functions +""" +import os +import tempfile +import pytest +from pathlib import Path +from unittest.mock import patch +import importlib.util +import sys +from pathlib import Path + +# Import the utils.py file directly +utils_spec = importlib.util.spec_from_file_location("utils", Path(__file__).parent.parent / "src" / "utils.py") +utils_module = importlib.util.module_from_spec(utils_spec) +sys.modules["utils"] = utils_module +utils_spec.loader.exec_module(utils_module) + +detect_project_language = utils_module.detect_project_language +calculate_loc = utils_module.calculate_loc + + +class TestDetectProjectLanguage: + """Test project language detection""" + + def test_detect_python_project(self): + """Test detecting Python project""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create Python files + Path(tmpdir, "main.py").touch() + Path(tmpdir, "utils.py").touch() + Path(tmpdir, "requirements.txt").touch() + + languages = detect_project_language(Path(tmpdir)) + 
assert "python" in languages + + def test_detect_java_project(self): + """Test detecting Java project""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create Java files + Path(tmpdir, "Main.java").touch() + Path(tmpdir, "Utils.java").touch() + Path(tmpdir, "pom.xml").touch() + + languages = detect_project_language(Path(tmpdir)) + assert "java" in languages + + def test_detect_javascript_project(self): + """Test detecting JavaScript project""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create JS files + Path(tmpdir, "app.js").touch() + Path(tmpdir, "package.json").touch() + + languages = detect_project_language(Path(tmpdir)) + assert "javascript" in languages + + def test_detect_c_project(self): + """Test detecting C project""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create C files + Path(tmpdir, "main.c").touch() + Path(tmpdir, "utils.h").touch() + + languages = detect_project_language(Path(tmpdir)) + assert "c" in languages + + def test_detect_cpp_project(self): + """Test detecting C++ project""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create C++ files + Path(tmpdir, "main.cpp").touch() + Path(tmpdir, "utils.hpp").touch() + + languages = detect_project_language(Path(tmpdir)) + assert "cpp" in languages + + def test_detect_go_project(self): + """Test detecting Go project""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create Go files + Path(tmpdir, "main.go").touch() + Path(tmpdir, "go.mod").touch() + + languages = detect_project_language(Path(tmpdir)) + assert "go" in languages + + def test_detect_kotlin_project(self): + """Test detecting Kotlin project""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create Kotlin files + Path(tmpdir, "Main.kt").touch() + Path(tmpdir, "Utils.kts").touch() + + languages = detect_project_language(Path(tmpdir)) + assert "kotlin" in languages + + def test_detect_csharp_project(self): + """Test detecting C# project""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create C# 
files + Path(tmpdir, "Program.cs").touch() + Path(tmpdir, "Utils.cs").touch() + + languages = detect_project_language(Path(tmpdir)) + assert "csharp" in languages + + def test_detect_multiple_languages(self): + """Test detecting multiple languages in one project""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create files for multiple languages + Path(tmpdir, "main.py").touch() + Path(tmpdir, "Main.java").touch() + Path(tmpdir, "app.js").touch() + + languages = detect_project_language(Path(tmpdir)) + assert "python" in languages + assert "java" in languages + assert "javascript" in languages + + def test_detect_unknown_language(self): + """Test detecting unknown language (no recognized files)""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create unrecognized files + Path(tmpdir, "README.md").touch() + Path(tmpdir, "Dockerfile").touch() + + languages = detect_project_language(Path(tmpdir)) + assert languages == ["unknown"] + + def test_detect_empty_directory(self): + """Test detecting language in empty directory""" + with tempfile.TemporaryDirectory() as tmpdir: + languages = detect_project_language(Path(tmpdir)) + assert languages == ["unknown"] + + def test_detect_nested_files(self): + """Test detecting languages in nested directories""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create nested structure + src_dir = Path(tmpdir, "src") + src_dir.mkdir() + Path(src_dir, "main.py").touch() + Path(src_dir, "utils.py").touch() + + test_dir = Path(tmpdir, "tests") + test_dir.mkdir() + Path(test_dir, "test_main.py").touch() + + languages = detect_project_language(Path(tmpdir)) + assert "python" in languages + + +class TestCalculateLoc: + """Test lines of code calculation""" + + def test_calculate_loc_python(self): + """Test LOC calculation for Python files""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create Python files with known line counts + py_file = Path(tmpdir, "test.py") + py_file.write_text("""# Comment line +import os +import 
sys + +def hello(): + print("Hello World") + return True + +if __name__ == "__main__": + hello() +""") + + loc = calculate_loc(Path(tmpdir), ["python"]) + assert loc == 8 # Count of non-empty lines + + def test_calculate_loc_java(self): + """Test LOC calculation for Java files""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create Java file + java_file = Path(tmpdir, "Test.java") + java_file.write_text("""public class Test { + public static void main(String[] args) { + System.out.println("Hello World"); + } +} +""") + + loc = calculate_loc(Path(tmpdir), ["java"]) + assert loc == 5 + + def test_calculate_loc_multiple_files(self): + """Test LOC calculation across multiple files""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create multiple Python files + file1 = Path(tmpdir, "file1.py") + file1.write_text("print('hello')\nprint('world')\n") + + file2 = Path(tmpdir, "file2.py") + file2.write_text("def test():\n return True\n") + + loc = calculate_loc(Path(tmpdir), ["python"]) + assert loc == 4 + + def test_calculate_loc_mixed_languages(self): + """Test LOC calculation with multiple languages""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create Python file + py_file = Path(tmpdir, "main.py") + py_file.write_text("print('python')\n") + + # Create Java file + java_file = Path(tmpdir, "Main.java") + java_file.write_text("System.out.println('java');\n") + + # Calculate for Python only + loc_python = calculate_loc(Path(tmpdir), ["python"]) + assert loc_python == 1 + + # Calculate for Java only + loc_java = calculate_loc(Path(tmpdir), ["java"]) + assert loc_java == 1 + + # Calculate for both + loc_both = calculate_loc(Path(tmpdir), ["python", "java"]) + assert loc_both == 2 + + def test_calculate_loc_empty_file(self): + """Test LOC calculation for empty file""" + with tempfile.TemporaryDirectory() as tmpdir: + empty_file = Path(tmpdir, "empty.py") + empty_file.touch() + + loc = calculate_loc(Path(tmpdir), ["python"]) + assert loc == 0 + + def 
test_calculate_loc_whitespace_only(self): + """Test LOC calculation for whitespace-only file""" + with tempfile.TemporaryDirectory() as tmpdir: + ws_file = Path(tmpdir, "whitespace.py") + ws_file.write_text(" \n\t\n \n") + + loc = calculate_loc(Path(tmpdir), ["python"]) + assert loc == 0 + + def test_calculate_loc_binary_file_ignored(self): + """Test that binary/unreadable files are ignored""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create a text file + text_file = Path(tmpdir, "test.py") + text_file.write_text("print('hello')\n") + + # Create a "binary" file (we'll just make it unreadable) + binary_file = Path(tmpdir, "binary.py") + binary_file.write_text("print('hello')\n") + binary_file.chmod(0o000) # Remove read permissions + + try: + loc = calculate_loc(Path(tmpdir), ["python"]) + # Should still count the readable file + assert loc == 1 + finally: + binary_file.chmod(0o644) # Restore permissions for cleanup + + def test_calculate_loc_no_matching_files(self): + """Test LOC calculation when no files match the language""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create Python file + py_file = Path(tmpdir, "test.py") + py_file.write_text("print('hello')\n") + + # Calculate for Java (no Java files) + loc = calculate_loc(Path(tmpdir), ["java"]) + assert loc == 0 + + def test_calculate_loc_unknown_language(self): + """Test LOC calculation for unknown language""" + with tempfile.TemporaryDirectory() as tmpdir: + py_file = Path(tmpdir, "test.py") + py_file.write_text("print('hello')\n") + + loc = calculate_loc(Path(tmpdir), ["unknown"]) + assert loc == 0 \ No newline at end of file diff --git a/tests/test_validators.py b/tests/test_validators.py new file mode 100644 index 0000000..ff301f2 --- /dev/null +++ b/tests/test_validators.py @@ -0,0 +1,306 @@ +""" +Tests for input validation utilities +""" +import pytest +from unittest.mock import patch +from src.utils.validators import ( + validate_source_type, + validate_language, + validate_session_id, 
+ validate_github_url, + validate_local_path, + validate_cpgql_query, + hash_query, + sanitize_path, + validate_timeout +) +from src.exceptions import ValidationError + + +class TestValidateSourceType: + """Test source type validation""" + + def test_valid_source_types(self): + """Test valid source types""" + valid_types = ["local", "github"] + + for source_type in valid_types: + # Should not raise + validate_source_type(source_type) + + def test_invalid_source_type(self): + """Test invalid source type""" + with pytest.raises(ValidationError) as exc_info: + validate_source_type("invalid") + + assert "Invalid source_type 'invalid'" in str(exc_info.value) + assert "Must be one of: local, github" in str(exc_info.value) + + +class TestValidateLanguage: + """Test language validation""" + + def test_valid_languages(self): + """Test valid programming languages""" + valid_languages = [ + "java", "c", "cpp", "javascript", "python", "go", "kotlin", + "csharp", "ghidra", "jimple", "php", "ruby", "swift" + ] + + for language in valid_languages: + # Should not raise + validate_language(language) + + def test_invalid_language(self): + """Test invalid programming language""" + with pytest.raises(ValidationError) as exc_info: + validate_language("rust") + + assert "Unsupported language 'rust'" in str(exc_info.value) + assert "Supported:" in str(exc_info.value) + + +class TestValidateSessionId: + """Test session ID validation""" + + def test_valid_session_id(self): + """Test valid UUID session ID""" + valid_uuid = "12345678-1234-5678-9012-123456789012" + # Should not raise + validate_session_id(valid_uuid) + + def test_invalid_session_id_empty(self): + """Test empty session ID""" + with pytest.raises(ValidationError) as exc_info: + validate_session_id("") + + assert "session_id must be a non-empty string" in str(exc_info.value) + + def test_invalid_session_id_none(self): + """Test None session ID""" + with pytest.raises(ValidationError) as exc_info: + validate_session_id(None) + + 
assert "session_id must be a non-empty string" in str(exc_info.value) + + def test_invalid_session_id_wrong_format(self): + """Test invalid UUID format""" + invalid_ids = [ + "not-a-uuid", + "12345678-1234-5678-9012", # Too short + "12345678-1234-5678-9012-123456789012-extra", # Too long + "12345678-1234-5678-g012-123456789012" # Invalid character + ] + + for invalid_id in invalid_ids: + with pytest.raises(ValidationError) as exc_info: + validate_session_id(invalid_id) + + assert "session_id must be a valid UUID" in str(exc_info.value) + + +class TestValidateGithubUrl: + """Test GitHub URL validation""" + + def test_valid_github_urls(self): + """Test valid GitHub URLs""" + valid_urls = [ + "https://github.com/user/repo", + "https://github.com/user/repo.git", + "https://www.github.com/user/repo", + "https://github.com/user-name/repo_name", + "https://github.com/user/repo/issues" + ] + + for url in valid_urls: + # Should not raise + validate_github_url(url) + + def test_invalid_github_urls(self): + """Test invalid GitHub URLs""" + invalid_urls = [ + "https://gitlab.com/user/repo", # Wrong domain + "https://github.com/user", # Missing repo + "https://github.com/", # Incomplete + "not-a-url" + ] + + for url in invalid_urls: + with pytest.raises(ValidationError): + validate_github_url(url) + + +class TestValidateLocalPath: + """Test local path validation""" + + def test_valid_local_path(self): + """Test valid local path""" + with patch('os.path.exists', return_value=True), \ + patch('os.path.isdir', return_value=True): + # Should not raise + validate_local_path("/valid/path") + + def test_invalid_local_path_not_absolute(self): + """Test relative path""" + with pytest.raises(ValidationError) as exc_info: + validate_local_path("relative/path") + + assert "Local path must be absolute" in str(exc_info.value) + + def test_invalid_local_path_not_exists(self): + """Test non-existent path""" + with patch('os.path.exists', return_value=False): + with 
pytest.raises(ValidationError) as exc_info: + validate_local_path("/nonexistent/path") + + assert "Path does not exist" in str(exc_info.value) + + def test_invalid_local_path_not_directory(self): + """Test path that exists but is not a directory""" + with patch('os.path.exists', return_value=True), \ + patch('os.path.isdir', return_value=False): + with pytest.raises(ValidationError) as exc_info: + validate_local_path("/path/to/file.txt") + + assert "Path is not a directory" in str(exc_info.value) + + +class TestValidateCpgqlQuery: + """Test CPGQL query validation""" + + def test_valid_queries(self): + """Test valid CPGQL queries""" + valid_queries = [ + "cpg.method.name.l", + "cpg.call.name('printf').l", + "cpg.literal.code('SELECT *').l", + "cpg.file.name.toJson", + "cpg.method.where(_.name('main')).l" + ] + + for query in valid_queries: + # Should not raise + validate_cpgql_query(query) + + def test_empty_query(self): + """Test empty query""" + with pytest.raises(ValidationError) as exc_info: + validate_cpgql_query("") + + assert "Query must be a non-empty string" in str(exc_info.value) + + def test_none_query(self): + """Test None query""" + with pytest.raises(ValidationError) as exc_info: + validate_cpgql_query(None) + + assert "Query must be a non-empty string" in str(exc_info.value) + + def test_query_too_long(self): + """Test query that exceeds length limit""" + long_query = "cpg.method.name.l" * 1000 # Very long query + + with pytest.raises(ValidationError) as exc_info: + validate_cpgql_query(long_query) + + assert "Query too long" in str(exc_info.value) + + def test_dangerous_queries(self): + """Test queries with dangerous operations""" + dangerous_queries = [ + "System.exit(0)", + "Runtime.getRuntime.exec('rm -rf /')", + "ProcessBuilder", + "java.io.File.delete" + ] + + for query in dangerous_queries: + with pytest.raises(ValidationError) as exc_info: + validate_cpgql_query(query) + + assert "potentially dangerous operation" in str(exc_info.value) + + 
class TestHashQuery:
    """Hashing of CPGQL queries for cache keys."""

    def test_hash_query_consistent(self):
        """The same query always hashes to the same digest."""
        first = hash_query("cpg.method.name.l")
        second = hash_query("cpg.method.name.l")

        assert first == second
        assert isinstance(first, str)
        assert len(first) == 64  # hex digest length of SHA-256

    def test_hash_query_different(self):
        """Distinct queries hash to distinct digests."""
        assert hash_query("cpg.method.name.l") != hash_query("cpg.call.name.l")


class TestSanitizePath:
    """Path sanitization against directory traversal."""

    def test_sanitize_path_safe(self):
        """Paths without traversal components come back unchanged."""
        for path in ("/safe/path", "/another/safe/path", "safe/path"):
            assert sanitize_path(path) == path

    def test_sanitize_path_traversal(self):
        """'..' components are removed from traversal attempts."""
        for path in (
            "../../../etc/passwd",
            "../../../../root",
            "..\\..\\..\\windows\\system32",
        ):
            assert ".." not in sanitize_path(path)


class TestValidateTimeout:
    """Validation of query timeout values."""

    def test_valid_timeout(self):
        """Timeouts within the allowed [1, 300] second range pass."""
        for timeout in (1, 30, 300, 100):
            validate_timeout(timeout)

    def test_invalid_timeout_zero(self):
        """Zero is below the one-second minimum."""
        with pytest.raises(ValidationError) as err:
            validate_timeout(0)

        assert "Timeout must be at least 1 second" in str(err.value)

    def test_invalid_timeout_negative(self):
        """Negative values are below the one-second minimum."""
        with pytest.raises(ValidationError) as err:
            validate_timeout(-1)

        assert "Timeout must be at least 1 second" in str(err.value)

    def test_invalid_timeout_too_large(self):
        """Values above the 300-second ceiling are rejected."""
        with pytest.raises(ValidationError) as err:
            validate_timeout(400)

        assert "Timeout cannot exceed 300 seconds" in str(err.value)