Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ jobs:
- name: Run tests
run: |
make test
- name: Run test example
run: |
make test-example
- name: Run test example parallel (xdist)
run: |
make test-example-parallel
- name: Run integration tests
run: |
Expand Down
13 changes: 9 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,20 @@ version:
@uv version

ruff:
@echo "Running ruff..."
@uv run ruff format .
@uv run ruff check .
@echo "Running ruff format on src, tests, and example..."
@uv run ruff format src tests example
@echo "Running ruff check on src"
@uv run ruff check src

mypy:
@echo "Running mypy..."
@uv run mypy

format: ruff mypy
vulture:
@echo "Running vulture..."
@uv run vulture

format: ruff mypy vulture

test:
@echo "Running plugin tests..."
Expand Down
24 changes: 24 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ dev = [
"pytest-xdist>=3.8.0",
"ruff>=0.12.3",
"typeguard>=4.4.4",
"vulture>=2.14",
]

# API COVERAGE
Expand Down Expand Up @@ -134,6 +135,22 @@ lint.ignore = [
"TD003",
"FIX002",
"PLC0415",
"PLR0912",
"PLR0915",
"C901",
# Print statements are fine for CLI tools
"T201",
# Any types are common in plugin/interop code
"ANN401",
# Exception patterns
"EM102",
# Try-except patterns
"S110",
"S112",
# Magic numbers
"PLR2004",
# Private member access
"SLF001",
]

[tool.ruff.lint.per-file-ignores]
Expand All @@ -157,3 +174,10 @@ pretty = true
show_column_numbers = true
show_error_codes = true
show_error_context = true

[tool.vulture]
exclude = []
min_confidence = 80
paths = ["src", "tests", "example"]
sort_by_size = true
verbose = false
2 changes: 2 additions & 0 deletions src/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
"""Package root."""

# This file makes the src directory a Python package
38 changes: 17 additions & 21 deletions src/pytest_api_cov/cli.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,42 @@
"""CLI commands for setup and configuration."""

import argparse
import os
import sys
from pathlib import Path
from typing import Optional, Tuple


def detect_framework_and_app() -> Optional[Tuple[str, str, str]]:
"""Detect framework and app location.

Returns (framework, file_path, app_variable) or None.
"""
import glob

# Search for app files at any depth in current directory
app_patterns = ["app.py", "main.py", "server.py", "wsgi.py", "asgi.py"]
common_vars = ["app", "application", "main", "server"]

# Find all matching files recursively
found_files = []
for pattern in app_patterns:
found_files.extend(glob.glob(f"**/{pattern}", recursive=True))
found_files = [file_path for pattern in app_patterns for file_path in Path().rglob(pattern)]

# Sort by depth (shallowest first) and then by filename priority
found_files.sort(key=lambda x: (x.count(os.sep), app_patterns.index(os.path.basename(x))))
found_files.sort(key=lambda p: (len(p.parts), app_patterns.index(p.name)))

for file_path in found_files:
try:
with open(file_path, "r") as f:
content = f.read()
content = file_path.read_text()

if "from fastapi import" in content or "import fastapi" in content:
framework = "FastAPI"
elif "from flask import" in content or "import flask" in content:
framework = "Flask"
else:
continue
continue # Not a framework file we care about

for var_name in common_vars:
if f"{var_name} = " in content:
return framework, file_path, var_name
return framework, file_path.as_posix(), var_name

except Exception:
except (IOError, UnicodeDecodeError):
continue

return None
Expand Down Expand Up @@ -88,7 +84,7 @@ def app():
'''


def generate_pyproject_config(framework: str) -> str:
def generate_pyproject_config() -> str:
"""Generate pyproject.toml configuration section."""
return """
# pytest-api-cov configuration
Expand Down Expand Up @@ -133,7 +129,7 @@ def cmd_init() -> int:
framework, file_path, app_variable = detection_result
print(f"✅ Detected {framework} app in {file_path} (variable: {app_variable})")

conftest_exists = os.path.exists("conftest.py")
conftest_exists = Path("conftest.py").exists()
if conftest_exists:
print("⚠️ conftest.py already exists")
create_conftest = input("Do you want to overwrite it? (y/N): ").lower().startswith("y")
Expand All @@ -142,28 +138,28 @@ def cmd_init() -> int:

if create_conftest:
conftest_content = generate_conftest_content(framework, file_path, app_variable)
with open("conftest.py", "w") as f:
with Path("conftest.py").open("w") as f:
f.write(conftest_content)
print("✅ Created conftest.py")

pyproject_exists = os.path.exists("pyproject.toml")
pyproject_exists = Path("pyproject.toml").exists()
if pyproject_exists:
print("ℹ️ pyproject.toml already exists")
print("ℹ️ pyproject.toml already exists") # noqa: RUF001
print("Add this configuration to your pyproject.toml:")
print(generate_pyproject_config(framework))
print(generate_pyproject_config())
else:
create_pyproject = input("Create pyproject.toml with pytest-api-cov config? (Y/n): ").lower()
if not create_pyproject.startswith("n"):
pyproject_content = f"""[project]
name = "your-project"
version = "0.1.0"

{generate_pyproject_config(framework)}
{generate_pyproject_config()}

[tool.pytest.ini_options]
testpaths = ["tests"]
"""
with open("pyproject.toml", "w") as f:
with Path("pyproject.toml").open("w") as f:
f.write(pyproject_content)
print("✅ Created pyproject.toml")

Expand Down Expand Up @@ -205,7 +201,7 @@ def read_root():


def main() -> int:
"""Main CLI entry point."""
"""Run the main CLI entry point."""
parser = argparse.ArgumentParser(prog="pytest-api-cov", description="pytest API coverage plugin CLI tools")

subparsers = parser.add_subparsers(dest="command", help="Available commands")
Expand Down
20 changes: 11 additions & 9 deletions src/pytest_api_cov/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Configuration handling for the API coverage report."""

import sys
from pathlib import Path
from typing import Any, Dict, List, Optional

import tomli
Expand All @@ -13,21 +14,21 @@ class ApiCoverageReportConfig(BaseModel):
model_config = ConfigDict(populate_by_name=True)

fail_under: Optional[float] = Field(None, alias="api-cov-fail-under")
show_uncovered_endpoints: bool = Field(True, alias="api-cov-show-uncovered-endpoints")
show_covered_endpoints: bool = Field(False, alias="api-cov-show-covered-endpoints")
show_excluded_endpoints: bool = Field(False, alias="api-cov-show-excluded-endpoints")
exclusion_patterns: List[str] = Field([], alias="api-cov-exclusion-patterns")
show_uncovered_endpoints: bool = Field(default=True, alias="api-cov-show-uncovered-endpoints")
show_covered_endpoints: bool = Field(default=False, alias="api-cov-show-covered-endpoints")
show_excluded_endpoints: bool = Field(default=False, alias="api-cov-show-excluded-endpoints")
exclusion_patterns: List[str] = Field(default=[], alias="api-cov-exclusion-patterns")
report_path: Optional[str] = Field(None, alias="api-cov-report-path")
force_sugar: bool = Field(False, alias="api-cov-force-sugar")
force_sugar_disabled: bool = Field(False, alias="api-cov-force-sugar-disabled")
force_sugar: bool = Field(default=False, alias="api-cov-force-sugar")
force_sugar_disabled: bool = Field(default=False, alias="api-cov-force-sugar-disabled")
client_fixture_name: str = Field("coverage_client", alias="api-cov-client-fixture-name")
group_methods_by_endpoint: bool = Field(False, alias="api-cov-group-methods-by-endpoint")
group_methods_by_endpoint: bool = Field(default=False, alias="api-cov-group-methods-by-endpoint")


def read_toml_config() -> Dict[str, Any]:
"""Read the [tool.pytest_api_cov] section from pyproject.toml."""
try:
with open("pyproject.toml", "rb") as f:
with Path("pyproject.toml").open("rb") as f:
toml_config = tomli.load(f)
return toml_config.get("tool", {}).get("pytest_api_cov", {}) # type: ignore[no-any-return]
except (FileNotFoundError, tomli.TOMLDecodeError):
Expand Down Expand Up @@ -65,7 +66,8 @@ def supports_unicode() -> bool:

def get_pytest_api_cov_report_config(session_config: Any) -> ApiCoverageReportConfig:
"""Get the final API coverage configuration by merging sources.
Priority: CLI > pyproject.toml > Defaults

Priority: CLI > pyproject.toml > Defaults.
"""
toml_config = read_toml_config()
cli_config = read_session_config(session_config)
Expand Down
42 changes: 26 additions & 16 deletions src/pytest_api_cov/frameworks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@


class BaseAdapter:
def __init__(self, app: Any):
"""Base adapter for framework applications."""

def __init__(self, app: Any) -> None:
"""Initialize the adapter."""
self.app = app

def get_endpoints(self) -> List[str]:
Expand All @@ -20,20 +23,23 @@ def get_tracked_client(self, recorder: Optional["ApiCallRecorder"], test_name: s


class FlaskAdapter(BaseAdapter):
"""Adapter for Flask applications."""

def get_endpoints(self) -> List[str]:
"""Return list of 'METHOD /path' strings."""
excluded_rules = ("/static/<path:filename>",)
endpoints = []

for rule in self.app.url_map.iter_rules():
if rule.rule not in excluded_rules:
for method in rule.methods:
if method not in ("HEAD", "OPTIONS"): # Skip automatic methods
endpoints.append(f"{method} {rule.rule}")
endpoints = [
f"{method} {rule.rule}"
for rule in self.app.url_map.iter_rules()
if rule.rule not in excluded_rules
for method in rule.methods
if method not in ("HEAD", "OPTIONS")
]

return sorted(endpoints)

def get_tracked_client(self, recorder: Optional["ApiCallRecorder"], test_name: str) -> Any:
"""Return a patched test client that records calls."""
from flask.testing import FlaskClient

if recorder is None:
Expand All @@ -49,28 +55,32 @@ def open(self, *args: Any, **kwargs: Any) -> Any:
endpoint_name, _ = self.application.url_map.bind("").match(path, method=method)
endpoint_rule_string = next(self.application.url_map.iter_rules(endpoint_name)).rule
recorder.record_call(endpoint_rule_string, test_name, method) # type: ignore[union-attr]
except Exception:
except Exception: # noqa: BLE001
pass
return super().open(*args, **kwargs)

return TrackingFlaskClient(self.app, self.app.response_class)


class FastAPIAdapter(BaseAdapter):
"""Adapter for FastAPI applications."""

def get_endpoints(self) -> List[str]:
"""Return list of 'METHOD /path' strings."""
from fastapi.routing import APIRoute

endpoints = []
for route in self.app.routes:
if isinstance(route, APIRoute):
for method in route.methods:
if method not in ("HEAD", "OPTIONS"):
endpoints.append(f"{method} {route.path}")
endpoints = [
f"{method} {route.path}"
for route in self.app.routes
if isinstance(route, APIRoute)
for method in route.methods
if method not in ("HEAD", "OPTIONS")
]

return sorted(endpoints)

def get_tracked_client(self, recorder: Optional["ApiCallRecorder"], test_name: str) -> Any:
"""Return a patched test client that records calls."""
from starlette.testclient import TestClient

if recorder is None:
Expand All @@ -89,7 +99,7 @@ def send(self, *args: Any, **kwargs: Any) -> Any:


def get_framework_adapter(app: Any) -> BaseAdapter:
"""Detects the framework and returns the appropriate adapter."""
"""Detect the framework and return the appropriate adapter."""
app_type = type(app).__name__
module_name = getattr(type(app), "__module__", "").split(".")[0]

Expand Down
4 changes: 0 additions & 4 deletions src/pytest_api_cov/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,6 @@ def __len__(self) -> int:
"""Return number of discovered endpoints."""
return len(self.endpoints)

def __iter__(self) -> Iterable[str]: # type: ignore[override]
"""Iterate over discovered endpoints."""
return iter(self.endpoints)


class SessionData(BaseModel):
"""Model for session-level API coverage data."""
Expand Down
Loading
Loading